现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.GroupNorm 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.GroupNorm 是 PyTorch 中的组归一化模块。

它将通道分成组进行归一化,不依赖 batch size,适合小 batch 或变长序列。

函数定义

torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True)

参数:

  • num_groups: 分组数量
  • num_channels: 通道数,必须能被 num_groups 整除

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

# GroupNorm: 4组,32通道
gn = nn.GroupNorm(num_groups=4, num_channels=32)

# 输入: (batch, channels, height, width)
x = torch.randn(8, 32, 16, 16)
output = gn(x)

print("输入形状:", x.shape)
print("输出形状:", output.shape)

示例 2: 特殊情况 - InstanceNorm

实例

import torch
import torch.nn as nn

# InstanceNorm 是 num_groups=num_channels 的特殊情况
inorm = nn.InstanceNorm2d(32)

# 等价于 GroupNorm, num_groups=32, num_channels=32
gn_equiv = nn.GroupNorm(32, 32)

x = torch.randn(4, 32, 16, 16)

print("InstanceNorm2d:", inorm(x).mean().item())
print("GroupNorm:", gn_equiv(x).mean().item())

示例 3: ResNet 中的 GroupNorm

实例

import torch
import torch.nn as nn

# ResNet 结构使用 GroupNorm
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, num_groups=8):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.gn = nn.GroupNorm(num_groups, out_ch)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.gn(self.conv(x)))

block = ConvBlock(64, 128)
x = torch.randn(2, 64, 32, 32)
output = block(x)

print("输入:", x.shape, "-> 输出:", output.shape)

示例 4: 不同 num_groups 的影响

实例

import torch
import torch.nn as nn

# 32通道,用不同的组数
for num_groups in [1, 2, 4, 8, 32]:
    gn = nn.GroupNorm(num_groups, 32)
    x = torch.randn(4, 32, 8, 8)
    out = gn(x)
    print(f"groups={num_groups}, 输出均值: {out.mean().item():.4f}")

使用场景

  • 小 batch: batch=1 时仍有效
  • 视频/3D: BatchNorm 不稳定
  • MobileNet: GroupNorm 替代 BatchNorm

注意:num_channels 必须能被 num_groups 整除。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册