
PyTorch torch.nn.ConvTranspose2d layer

PyTorch torch.nn reference manual


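torch.nn.ConvTranspose2d applies a 2D transposed convolution in PyTorch. It is also known as a fractionally-strided convolution, and is often (loosely) called a deconvolution or upsampling convolution.

It is mainly used to upsample feature maps, that is, to increase their height and width with learnable weights, and it is a key building block of generative networks and segmentation networks.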
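To make the mechanism concrete, here is a minimal pure-Python sketch of a 1D transposed convolution (single channel, no bias, no padding; the function name conv_transpose1d_naive is hypothetical, not a PyTorch API). Each input element scales a copy of the kernel, and the copies are added into the output at positions spaced stride apart, which is why the output is longer than the input and why overlapping copies are summed.

```python
def conv_transpose1d_naive(x, kernel, stride=1):
    """Hypothetical helper: single-channel 1D transposed convolution, no padding, no bias."""
    k = len(kernel)
    # Output length for padding=0, dilation=1, output_padding=0
    out_len = (len(x) - 1) * stride + k
    out = [0.0] * out_len
    for i, xi in enumerate(x):
        # Scatter: input element i adds xi * kernel into out[i*stride : i*stride + k]
        for j, kj in enumerate(kernel):
            out[i * stride + j] += xi * kj
    return out


y = conv_transpose1d_naive([1.0, 2.0, 3.0], [1.0, 1.0, 1.0], stride=2)
print(y)  # [1.0, 1.0, 3.0, 2.0, 5.0, 3.0, 3.0]
```

Positions 2 and 4 receive contributions from two neighbouring input elements, because the kernel (length 3) is longer than the stride (2).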

Signature

torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', device=None, dtype=None)

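Parameters

  • in_channels (int): number of channels in the input
  • out_channels (int): number of channels produced by the layer
  • kernel_size (int or tuple): size of the convolution kernel
  • stride (int or tuple, default 1): stride of the convolution; the output size grows roughly by a factor of stride
  • padding (int or tuple, default 0): dilation * (kernel_size - 1) - padding zeros are implicitly added to both sides of each input dimension, so a larger padding makes the output smaller
  • output_padding (int or tuple, default 0): extra size added to one side of each output dimension, used to select the intended output size when stride > 1; must be smaller than either stride or dilation
  • groups (int, default 1): number of blocked connections from input channels to output channels; in_channels and out_channels must both be divisible by groups
  • bias (bool, default True): if True, adds a learnable bias to the output
  • dilation (int or tuple, default 1): spacing between kernel elements
  • padding_mode (str, default 'zeros'): only 'zeros' is supported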
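The output size along each spatial dimension follows the formula given in the PyTorch documentation: H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1. A small pure-Python helper (the name convtranspose_out_size is hypothetical) makes it easy to check the examples by hand:

```python
def convtranspose_out_size(h_in, kernel_size, stride=1, padding=0,
                           output_padding=0, dilation=1):
    """Hypothetical helper: size of one spatial dimension after ConvTranspose2d."""
    # Same constraint PyTorch enforces: output_padding < stride or output_padding < dilation
    if not 0 <= output_padding < max(stride, dilation):
        raise ValueError("output_padding must be smaller than stride or dilation")
    return ((h_in - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


print(convtranspose_out_size(16, 2, stride=2))                               # 32
print(convtranspose_out_size(4, 4, stride=2, padding=1))                     # 8
print(convtranspose_out_size(8, 3, stride=2, padding=1))                     # 15
print(convtranspose_out_size(8, 3, stride=2, padding=1, output_padding=1))   # 16
```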

Examples

Example 1: Basic usage


import torch
import torch.nn as nn

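# Transposed convolution: kernel_size=2 with stride=2 doubles the height and width
deconv = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2)

x = torch.randn(1, 64, 16, 16)
output = deconv(x)

print("Input:", x.shape, "-> Output:", output.shape)
# Input: torch.Size([1, 64, 16, 16]) -> Output: torch.Size([1, 64, 32, 32])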
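When kernel_size equals stride, as in the basic example above, the kernel copies do not overlap, so each input pixel is expanded into its own k x k block. As a sanity check, the sketch below fixes a single-channel kernel to all ones with no bias (an illustrative choice; in a real model the weights are learned) and compares the result with nearest-neighbor upsampling from torch.nn.functional.interpolate.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

deconv = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2, bias=False)
with torch.no_grad():
    # Weight shape is (in_channels, out_channels, kH, kW) = (1, 1, 2, 2)
    deconv.weight.fill_(1.0)

x = torch.arange(1.0, 5.0).view(1, 1, 2, 2)  # [[1, 2], [3, 4]]
with torch.no_grad():
    y = deconv(x)
nearest = F.interpolate(x, scale_factor=2, mode="nearest")

print(y[0, 0])
print(torch.allclose(y, nearest))  # True: each pixel became a 2x2 block
```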

Example 2: A generator network


import torch
import torch.nn as nn

# Simplified DCGAN-style generator: each ConvTranspose2d(k=4, s=2, p=1) doubles the spatial size
class Generator(nn.Module):
    def __init__(self, latent_dim=100):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 512 * 4 * 4)

        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 2, 1),  # 4->8
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # 8->16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # 16->32
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),      # 32->64
            nn.Tanh()
        )

    def forward(self, x):
        x = self.fc(x).view(-1, 512, 4, 4)
        return self.deconv(x)

gen = Generator()
z = torch.randn(1, 100)
img = gen(z)

print("Noise:", z.shape, "-> Image:", img.shape)
# Noise: torch.Size([1, 100]) -> Image: torch.Size([1, 3, 64, 64])
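The DCGAN-style layers above all use kernel_size=4, stride=2, padding=1. Substituting these into the output-size formula gives (H - 1) * 2 - 2 + 3 + 1 = 2H, so this setting exactly doubles the spatial size for any input size. A quick check (the channel count of 8 is arbitrary):

```python
import torch
import torch.nn as nn

up = nn.ConvTranspose2d(8, 8, kernel_size=4, stride=2, padding=1)

sizes = {}
for h in (3, 4, 7, 16):
    out = up(torch.randn(1, 8, h, h))
    sizes[h] = tuple(out.shape[-2:])
    print(h, "->", sizes[h])  # 3 -> (6, 6), 4 -> (8, 8), 7 -> (14, 14), 16 -> (32, 32)
```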

Example 3: Upsampling in a segmentation network


import torch
import torch.nn as nn

# Decoder part of a U-Net (simplified: skip connections omitted)
decoder = nn.Sequential(
    nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
    nn.Conv2d(128, 128, 3, padding=1),
    nn.ReLU(),
    nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
    nn.Conv2d(64, 64, 3, padding=1),
    nn.ReLU()
)

x = torch.randn(1, 256, 8, 8)
output = decoder(x)

print("Input:", x.shape, "-> Output:", output.shape)
# Input: torch.Size([1, 256, 8, 8]) -> Output: torch.Size([1, 64, 32, 32])
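The decoder above leaves out the skip connections that characterize U-Net: there, each upsampled map is concatenated along the channel dimension with the matching encoder feature map before the convolutions. A minimal sketch of such a decoder block follows. The class name UpBlock and the channel counts are hypothetical, and it assumes the encoder map already has the same spatial size (padded 3x3 convolutions, as many reimplementations use), so no cropping is done.

```python
import torch
import torch.nn as nn


class UpBlock(nn.Module):
    """Hypothetical U-Net decoder block: upsample, concatenate the skip tensor, then convolve."""

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        # Halve the channel count and double the spatial size
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels // 2 + skip_channels, out_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x, skip):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)  # concatenate along the channel dimension
        return self.conv(x)


block = UpBlock(in_channels=256, skip_channels=128, out_channels=128)
x = torch.randn(1, 256, 8, 8)        # decoder input
skip = torch.randn(1, 128, 16, 16)   # matching encoder feature map
out = block(x, skip)
print(out.shape)  # torch.Size([1, 128, 16, 16])
```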

Example 4: Output size for different stride and padding settings


import torch
import torch.nn as nn

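# Each tuple is (stride, kernel_size, padding)
configs = [
    (1, 1, 0),  # stride=1, kernel_size=1, padding=0
    (2, 1, 0),  # stride=2, kernel_size=1, padding=0
    (2, 1, 1),  # stride=2, kernel_size=1, padding=1
]

x = torch.randn(1, 64, 4, 4)

for stride, k, p in configs:
    deconv = nn.ConvTranspose2d(64, 64, kernel_size=k, stride=stride, padding=p)
    out = deconv(x)
    # H_out = (H_in - 1) * stride - 2 * padding + (kernel_size - 1) + 1  (dilation=1, output_padding=0)
    expected = (x.shape[-1] - 1) * stride - 2 * p + (k - 1) + 1
    print(f"k={k}, s={stride}, p={p}: {x.shape} -> {out.shape}, formula: {expected}")
# k=1, s=1, p=0: torch.Size([1, 64, 4, 4]) -> torch.Size([1, 64, 4, 4]), formula: 4
# k=1, s=2, p=0: torch.Size([1, 64, 4, 4]) -> torch.Size([1, 64, 7, 7]), formula: 7
# k=1, s=2, p=1: torch.Size([1, 64, 4, 4]) -> torch.Size([1, 64, 5, 5]), formula: 5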
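When stride > 1, several input sizes of a Conv2d can map to the same output size, so the matching transposed layer cannot tell which size to restore. For example, Conv2d(kernel_size=3, stride=2, padding=1) turns both 7x7 and 8x8 inputs into 4x4. The output_padding argument (or the output_size argument of the forward call) selects the intended size. A short sketch, with an arbitrary channel count of 16:

```python
import torch
import torch.nn as nn

# Downsampling side: two different input sizes collapse to the same output size
down = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
down_sizes = {s: tuple(down(torch.randn(1, 16, s, s)).shape[-2:]) for s in (7, 8)}
print(down_sizes)  # {7: (4, 4), 8: (4, 4)}

z = torch.randn(1, 16, 4, 4)
up0 = nn.ConvTranspose2d(16, 16, kernel_size=3, stride=2, padding=1)
up1 = nn.ConvTranspose2d(16, 16, kernel_size=3, stride=2, padding=1, output_padding=1)

print(tuple(up0(z).shape[-2:]))                      # (7, 7)
print(tuple(up1(z).shape[-2:]))                      # (8, 8)
# Alternatively, request the size when calling the layer
print(tuple(up0(z, output_size=[8, 8]).shape[-2:]))  # (8, 8)
```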

Use cases

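  • Generative networks: GAN generators, VAE decoders
  • Semantic segmentation: U-Net decoders
  • Upsampling: a learnable alternative to fixed upsampling such as interpolation or unpooling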
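A common alternative for learnable upsampling is to resize with a fixed interpolation first and then apply an ordinary convolution. It is sometimes preferred because overlapping transposed-convolution kernels can produce checkerboard artifacts. The sketch below (channel counts arbitrary) only shows that both options produce the same output shape, so one can be swapped for the other in a model; which works better depends on the task.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 16, 16)

# Option A: transposed convolution (learned upsampling)
deconv_up = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)

# Option B: fixed nearest-neighbor resize followed by a regular convolution
resize_conv_up = nn.Sequential(
    nn.Upsample(scale_factor=2, mode="nearest"),
    nn.Conv2d(64, 32, kernel_size=3, padding=1),
)

a = deconv_up(x)
b = resize_conv_up(x)
print(a.shape, b.shape)  # torch.Size([1, 32, 32, 32]) torch.Size([1, 32, 32, 32])
```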

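Note: despite the name "deconvolution", a transposed convolution is not the inverse of convolution. With matching settings it restores the input shape of the corresponding Conv2d (it computes the gradient of Conv2d with respect to its input), but not the original values; it is simply one way to upsample.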
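This can be checked numerically. With a shared weight tensor, conv_transpose2d is the adjoint (transpose) of conv2d: for any x and v, sum(conv2d(x) * v) equals sum(x * conv_transpose2d(v)). Applying it after conv2d, however, does not give x back. The sketch uses the functional forms in torch.nn.functional, arbitrary shapes, and float64 to keep rounding error small.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 3, 8, 8, dtype=torch.float64)
w = torch.randn(4, 3, 3, 3, dtype=torch.float64)  # conv2d weight: (out_ch, in_ch, kH, kW)

y = F.conv2d(x, w, stride=2, padding=1)  # shape (1, 4, 4, 4)
v = torch.randn_like(y)

# The same w is read as a transposed-conv weight (in_ch=4, out_ch=3, kH, kW);
# output_padding=1 restores 8x8 instead of 7x7
xt = F.conv_transpose2d(v, w, stride=2, padding=1, output_padding=1)

lhs = (y * v).sum()
rhs = (x * xt).sum()
print(torch.allclose(lhs, rhs))  # True: adjoint identity

x_rec = F.conv_transpose2d(y, w, stride=2, padding=1, output_padding=1)
print(x_rec.shape == x.shape, torch.allclose(x_rec, x))  # True False: same shape, different values
```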

