PyTorch torch.nn.ConvTranspose2d
torch.nn.ConvTranspose2d is PyTorch's 2D transposed convolution, also known as deconvolution or upsampling convolution.
It is used to upsample feature maps and is a key building block of generative networks and segmentation networks.
Definition
torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1)
Parameters
- in_channels: number of channels in the input feature map
- out_channels: number of channels in the output feature map
- kernel_size: size of the convolution kernel
- stride: stride of the convolution (default: 1)
- padding: implicit zero padding; reduces each output dimension by 2 * padding (default: 0)
- output_padding: extra size added to one side of each output dimension; must be smaller than stride (default: 0)
- groups: number of blocked connections from input channels to output channels (default: 1)
- bias: if True, adds a learnable bias to the output (default: True)
- dilation: spacing between kernel elements (default: 1)
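For each spatial dimension, the output size is H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1. A quick sanity check of this formula (the layer sizes here are arbitrary):

```python
import torch
import torch.nn as nn

def out_size(h_in, k, s=1, p=0, op=0, d=1):
    # Output-size formula for ConvTranspose2d (per spatial dimension)
    return (h_in - 1) * s - 2 * p + d * (k - 1) + op + 1

deconv = nn.ConvTranspose2d(8, 8, kernel_size=4, stride=2, padding=1)
x = torch.randn(1, 8, 16, 16)
h = deconv(x).shape[-1]
print(h, out_size(16, k=4, s=2, p=1))  # 32 32
```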
Usage Examples
Example 1: Basic usage
import torch
import torch.nn as nn
# Transposed convolution: upsampling
deconv = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2)
x = torch.randn(1, 64, 16, 16)
output = deconv(x)
print("Input:", x.shape, "-> Output:", output.shape)
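When stride > 1, several input sizes map to the same convolution output size, so the "inverse" spatial mapping is ambiguous; output_padding selects among the possible output sizes. A small sketch (shapes chosen for illustration):

```python
import torch
import torch.nn as nn

# With kernel_size=3, stride=2, padding=1, both 7x7 and 8x8 inputs
# convolve down to 4x4, so the transposed op needs output_padding
# to disambiguate which output size to produce.
x = torch.randn(1, 16, 4, 4)
up_a = nn.ConvTranspose2d(16, 16, kernel_size=3, stride=2, padding=1)
up_b = nn.ConvTranspose2d(16, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
print(up_a(x).shape)  # torch.Size([1, 16, 7, 7])
print(up_b(x).shape)  # torch.Size([1, 16, 8, 8])
```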
Example 2: Generator network
import torch
import torch.nn as nn
# A simplified DCGAN generator
class Generator(nn.Module):
    def __init__(self, latent_dim=100):
        super(Generator, self).__init__()
        self.fc = nn.Linear(latent_dim, 512 * 4 * 4)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 2, 1),  # 4 -> 8
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # 8 -> 16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # 16 -> 32
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),     # 32 -> 64
            nn.Tanh()
        )

    def forward(self, x):
        x = self.fc(x).view(-1, 512, 4, 4)
        return self.deconv(x)

gen = Generator()
z = torch.randn(1, 100)
img = gen(z)
print("Noise:", z.shape, "-> Image:", img.shape)
Example 3: Upsampling in a segmentation network
import torch
import torch.nn as nn
# Decoder part of a U-Net-style network
decoder = nn.Sequential(
    nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
    nn.Conv2d(128, 128, 3, padding=1),
    nn.ReLU(),
    nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
    nn.Conv2d(64, 64, 3, padding=1),
    nn.ReLU()
)
x = torch.randn(1, 256, 8, 8)
output = decoder(x)
print("Input:", x.shape, "-> Output:", output.shape)
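In a full U-Net, each upsampling step is typically followed by concatenating the matching encoder feature map (the skip connection) before the convolutions. A minimal sketch of one such step (channel sizes here are illustrative):

```python
import torch
import torch.nn as nn

# One U-Net-style decoder step: upsample, concatenate the encoder's
# skip-connection feature map along the channel dim, then convolve.
up = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
conv = nn.Conv2d(128 + 128, 128, kernel_size=3, padding=1)

x = torch.randn(1, 256, 8, 8)       # decoder input
skip = torch.randn(1, 128, 16, 16)  # matching encoder feature map
x = up(x)                           # -> (1, 128, 16, 16)
x = torch.relu(conv(torch.cat([x, skip], dim=1)))
print(x.shape)  # torch.Size([1, 128, 16, 16])
```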
Example 4: Output size with stride=2
import torch
import torch.nn as nn
# Different (kernel_size, stride, padding) configurations
configs = [
    (2, 2, 0),  # k=2, stride=2, padding=0: 4 -> 8
    (3, 2, 1),  # k=3, stride=2, padding=1: 4 -> 7
    (4, 2, 1),  # k=4, stride=2, padding=1: 4 -> 8
]
x = torch.randn(1, 64, 4, 4)
for k, stride, p in configs:
    deconv = nn.ConvTranspose2d(64, 64, kernel_size=k, stride=stride, padding=p)
    out = deconv(x)
    print(f"k={k}, s={stride}, p={p}: {x.shape} -> {out.shape}")
Use Cases
- Generative networks: GANs, VAEs
- Semantic segmentation: U-Net
- Upsampling: as a replacement for pooling
Note: a transposed convolution is not the inverse of convolution; it is just one way to upsample.
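To illustrate the note above, a sketch (with random, untrained weights) showing that a ConvTranspose2d with matching hyperparameters restores the spatial shape of a convolved input but not its values:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, kernel_size=2, stride=2)             # 8x8 -> 4x4
deconv = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2)  # 4x4 -> 8x8

x = torch.randn(1, 1, 8, 8)
y = deconv(conv(x))
print(y.shape)               # torch.Size([1, 1, 8, 8]) -- shape restored
print(torch.allclose(y, x))  # False -- values are not recovered
```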

PyTorch torch.nn Reference