
PyTorch torch.nn.Conv3d Module

PyTorch torch.nn Reference Manual


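torch.nn.Conv3d is PyTorch's 3D convolution layer. It slides a learnable kernel along three spatial dimensions at once.

It operates on volumetric data with three spatial axes (depth × height × width). Examples include video clips, where depth is the frame (time) axis, and medical scans such as CT or MRI volumes, where depth is the slice axis.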
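For video, the frames usually have to be rearranged before they can be fed to Conv3d. The sketch below assumes a decoder that returns a clip as a (T, H, W, C) uint8 tensor, a common frame-major layout. That layout and the sizes are illustrative assumptions, not something Conv3d requires. The clip is permuted so that channels come first and time becomes the depth axis.

```python
import torch

# Hypothetical decoded clip: 16 RGB frames of 112x112, stored frame-major
# as (T, H, W, C) uint8 (assumed layout; adjust to what your decoder returns).
clip = torch.randint(0, 256, (16, 112, 112, 3), dtype=torch.uint8)

# Conv3d wants (N, C, D, H, W): move channels first, use time as depth,
# convert to float in [0, 1], then add a batch dimension of size 1.
x = clip.permute(3, 0, 1, 2).float().div(255.0).unsqueeze(0)
print(x.shape)  # torch.Size([1, 3, 16, 112, 112])
```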

Definition

torch.nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)
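kernel_size, stride, padding and dilation accept either a single int, which is used for depth, height and width alike, or a 3-tuple (d, h, w). The sketch below uses per-axis values. The specific choice of a 3×7×7 kernel that downsamples only height and width is an illustrative assumption. The weight tensor has shape (out_channels, in_channels // groups, k_d, k_h, k_w).

```python
import torch
import torch.nn as nn

# Kernel covers 3 frames and 7x7 pixels; stride 2 only on height/width,
# so the temporal length is preserved while the spatial size is halved.
conv = nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3))

x = torch.randn(1, 3, 16, 112, 112)
y = conv(x)
print(y.shape)            # torch.Size([1, 64, 16, 56, 56])
print(conv.weight.shape)  # torch.Size([64, 3, 3, 7, 7])
print(conv.bias.shape)    # torch.Size([64])
```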

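Input and Output Shapes

Input: (N, C_in, D_in, H_in, W_in), i.e. (batch, channels, depth, height, width)
Output: (N, C_out, D_out, H_out, W_out), where C_out = out_channels and D_out, H_out, W_out depend on kernel_size, stride, padding and dilation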
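The PyTorch documentation gives the output size along each spatial axis as floor((size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1), applied separately to depth, height and width. The sketch below implements that formula in a helper, conv3d_out_size, whose name is hypothetical. It then checks the helper against a real layer, using stride=2 as an arbitrary example setting.

```python
import torch
import torch.nn as nn

def conv3d_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis (hypothetical helper name)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

in_dhw = (16, 32, 32)          # (depth, height, width)
k, s, p, d = 3, 2, 1, 1        # kernel_size, stride, padding, dilation
expected = tuple(conv3d_out_size(n, k, s, p, d) for n in in_dhw)

conv = nn.Conv3d(3, 8, kernel_size=k, stride=s, padding=p, dilation=d)
y = conv(torch.randn(1, 3, *in_dhw))
print(expected)        # (8, 16, 16)
print(tuple(y.shape))  # (1, 8, 8, 16, 16)
```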

Usage Examples

Example 1: Basic Usage


import torch
import torch.nn as nn

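# 3D convolution: 3 input channels -> 32 output channels, 3x3x3 kernel;
# padding=1 keeps depth, height and width unchanged
conv3d = nn.Conv3d(in_channels=3, out_channels=32, kernel_size=3, padding=1)

# Input: batch=2, channels=3, depth=16, height=32, width=32
x = torch.randn(2, 3, 16, 32, 32)
output = conv3d(x)

print("Input shape:", x.shape)        # torch.Size([2, 3, 16, 32, 32])
print("Output shape:", output.shape)  # torch.Size([2, 32, 16, 32, 32])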
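Instead of computing the padding by hand, padding can also be given as the string "same". This form is accepted only when stride is 1, and the sketch assumes PyTorch 1.9 or later, where it was introduced. For a 3×3×3 kernel it keeps the same sizes as padding=1 in Example 1. The default padding=0 ("valid") shrinks each spatial axis by kernel_size - 1.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 3, 16, 32, 32)

# padding="same": PyTorch pads so depth, height and width are preserved
conv_same = nn.Conv3d(3, 32, kernel_size=3, padding="same")
y_same = conv_same(x)
print(y_same.shape)   # torch.Size([2, 32, 16, 32, 32])

# Default padding=0: each spatial axis loses kernel_size - 1 = 2 elements
conv_valid = nn.Conv3d(3, 32, kernel_size=3)
y_valid = conv_valid(x)
print(y_valid.shape)  # torch.Size([2, 32, 14, 30, 30])
```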

Example 2: Video Classification


import torch
import torch.nn as nn

class VideoCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv3d(3, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool3d(2)          # halves depth, height and width
        self.conv2 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        self.gap = nn.AdaptiveAvgPool3d(1)   # global average over D, H, W
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        # Shape comments assume an (N, 3, 16, 112, 112) input
        x = torch.relu(self.conv1(x))  # (N, 32, 16, 112, 112)
        x = self.pool(x)               # (N, 32, 8, 56, 56)
        x = torch.relu(self.conv2(x))  # (N, 64, 8, 56, 56)
        x = self.gap(x)                # (N, 64, 1, 1, 1)
        x = torch.flatten(x, 1)        # (N, 64)
        return self.fc(x)              # (N, num_classes)

model = VideoCNN()
x = torch.randn(4, 3, 16, 112, 112)  # 4 clips, each 16 RGB frames of 112x112
output = model(x)

print("Video:", x.shape, "-> logits:", output.shape)  # logits: torch.Size([4, 10])
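To see how each stage of the video classifier changes the tensor, the same layers can be written as an nn.Sequential and stepped through one at a time. This is a rewrite for inspection only. nn.ReLU and nn.Flatten stand in for torch.relu and the flatten call in VideoCNN.forward. The batch size of 1 is an arbitrary choice to keep the run cheap.

```python
import torch
import torch.nn as nn

# Same layer sequence as VideoCNN, expressed as nn.Sequential
layers = nn.Sequential(
    nn.Conv3d(3, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool3d(2),
    nn.Conv3d(32, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool3d(1),
    nn.Flatten(),
    nn.Linear(64, 10),
)

x = torch.randn(1, 3, 16, 112, 112)
shapes = []
with torch.no_grad():  # inference only, no autograd bookkeeping
    for layer in layers:
        x = layer(x)
        shapes.append(tuple(x.shape))
        print(f"{type(layer).__name__:>18}: {tuple(x.shape)}")
```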

Example 3: Parameter Count


import torch
import torch.nn as nn

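conv = nn.Conv3d(64, 128, kernel_size=3, padding=1)

# Parameter count: out_ch * in_ch * k_d * k_h * k_w weights + out_ch biases
params = 128 * 64 * 3 * 3 * 3 + 128
print("Parameters:", params)                                           # 221312
print("From the module:", sum(p.numel() for p in conv.parameters()))   # 221312
print("Weight shape:", conv.weight.shape)  # torch.Size([128, 64, 3, 3, 3])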
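With groups > 1, each output channel sees only in_channels // groups input channels. The general count is therefore out_ch * (in_ch // groups) * k_d * k_h * k_w, plus out_ch if bias=True. The helper conv3d_num_params below is hypothetical, and the three configurations are arbitrary test cases: standard, grouped, and depthwise with a temporal-only kernel. Each count is checked against the module's actual parameters.

```python
import torch.nn as nn

def conv3d_num_params(in_ch, out_ch, k, groups=1, bias=True):
    """Hypothetical helper: weight shape is (out_ch, in_ch // groups, kd, kh, kw)."""
    kd, kh, kw = (k, k, k) if isinstance(k, int) else k
    weight = out_ch * (in_ch // groups) * kd * kh * kw
    return weight + (out_ch if bias else 0)

configs = [
    (64, 128, 3, 1, True),           # standard convolution
    (64, 128, 3, 4, True),           # grouped convolution
    (64, 64, (3, 1, 1), 64, False),  # depthwise, temporal-only kernel
]
for in_ch, out_ch, k, g, b in configs:
    conv = nn.Conv3d(in_ch, out_ch, kernel_size=k, groups=g, bias=b)
    actual = sum(p.numel() for p in conv.parameters())
    assert actual == conv3d_num_params(in_ch, out_ch, k, g, b)
    print((in_ch, out_ch, k, g, b), "->", actual)
```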

Use Cases

  • Video processing: action recognition
  • Medical imaging: CT, MRI
  • 3D segmentation: voxel data
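For voxel-level tasks such as 3D segmentation, a common pattern is to keep the volume size fixed with padding. A final 1×1×1 convolution then maps features to one score per class at every voxel. The sketch below is a toy head, not a real segmentation network. The single-channel 32×64×64 volume and the 3 classes are hypothetical values.

```python
import torch
import torch.nn as nn

num_classes = 3  # hypothetical number of tissue/label classes

# padding=1 preserves D, H, W; the 1x1x1 conv produces per-voxel class scores
seg_head = nn.Sequential(
    nn.Conv3d(1, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv3d(16, num_classes, kernel_size=1),
)

volume = torch.randn(1, 1, 32, 64, 64)  # single-channel scan: 32 slices of 64x64
logits = seg_head(volume)               # (1, 3, 32, 64, 64)
labels = logits.argmax(dim=1)           # (1, 32, 64, 64): one class id per voxel
print(logits.shape, labels.shape)
```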

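Note: 3D convolution is expensive in both compute and memory. Compared with a 2D convolution, the kernel gains a depth dimension, so each output position costs k_d times more multiply-adds. Every activation also gains a depth axis, so feature maps take correspondingly more GPU memory.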
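The cost can be estimated before running a model. The sketch below counts multiply-accumulates (MACs, bias ignored) as batch * out_ch * output_voxels * (in_ch // groups) * k_d * k_h * k_w. It also estimates the size of the output activation assuming float32 (4 bytes per element). The helper conv3d_cost is hypothetical. It is applied to conv1 of the video classifier in Example 2, and then to a (1, 3, 3) kernel with no temporal extent for comparison. Training needs more memory than this single-tensor figure, because autograd keeps intermediate activations for the backward pass.

```python
def conv3d_cost(in_ch, out_ch, k, out_dhw, groups=1, batch=1, bytes_per_elem=4):
    """Hypothetical helper: (MACs, output activation bytes) for one Conv3d call."""
    kd, kh, kw = (k, k, k) if isinstance(k, int) else k
    voxels = out_dhw[0] * out_dhw[1] * out_dhw[2]
    macs = batch * out_ch * voxels * (in_ch // groups) * kd * kh * kw
    out_bytes = batch * out_ch * voxels * bytes_per_elem
    return macs, out_bytes

# conv1 of VideoCNN: 3 -> 32 channels, 3x3x3 kernel, output 16x112x112, batch 4
macs, out_bytes = conv3d_cost(3, 32, 3, (16, 112, 112), batch=4)
# Same layer with a (1, 3, 3) kernel, i.e. no temporal extent
macs_2d, _ = conv3d_cost(3, 32, (1, 3, 3), (16, 112, 112), batch=4)

print(f"3x3x3 kernel: {macs / 1e9:.2f} GMACs, output {out_bytes / 2**20:.1f} MiB")
print(f"1x3x3 kernel: {macs_2d / 1e9:.2f} GMACs")
```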

