PyTorch torch.nn.GroupNorm 函数
torch.nn.GroupNorm 是 PyTorch 中的组归一化模块。
它将通道分成组进行归一化,不依赖 batch size,适合小 batch 或变长序列。
函数定义
torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True)
参数:
num_groups: 分组数量num_channels: 通道数,必须能被 num_groups 整除
使用示例
示例 1: 基本用法
实例
import torch
import torch.nn as nn
# GroupNorm: 4组,32通道
gn = nn.GroupNorm(num_groups=4, num_channels=32)
# 输入: (batch, channels, height, width)
x = torch.randn(8, 32, 16, 16)
output = gn(x)
print("输入形状:", x.shape)
print("输出形状:", output.shape)
import torch.nn as nn
# GroupNorm: 4组,32通道
gn = nn.GroupNorm(num_groups=4, num_channels=32)
# 输入: (batch, channels, height, width)
x = torch.randn(8, 32, 16, 16)
output = gn(x)
print("输入形状:", x.shape)
print("输出形状:", output.shape)
示例 2: 特殊情况 - InstanceNorm
实例
import torch
import torch.nn as nn
# InstanceNorm 是 num_groups=num_channels 的特殊情况
inorm = nn.InstanceNorm2d(32)
# 等价于 GroupNorm, num_groups=32, num_channels=32
gn_equiv = nn.GroupNorm(32, 32)
x = torch.randn(4, 32, 16, 16)
print("InstanceNorm2d:", inorm(x).mean().item())
print("GroupNorm:", gn_equiv(x).mean().item())
import torch.nn as nn
# InstanceNorm 是 num_groups=num_channels 的特殊情况
inorm = nn.InstanceNorm2d(32)
# 等价于 GroupNorm, num_groups=32, num_channels=32
gn_equiv = nn.GroupNorm(32, 32)
x = torch.randn(4, 32, 16, 16)
print("InstanceNorm2d:", inorm(x).mean().item())
print("GroupNorm:", gn_equiv(x).mean().item())
示例 3: ResNet 中的 GroupNorm
实例
import torch
import torch.nn as nn
# ResNet 结构使用 GroupNorm
class ConvBlock(nn.Module):
def __init__(self, in_ch, out_ch, num_groups=8):
super(ConvBlock, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
self.gn = nn.GroupNorm(num_groups, out_ch)
self.relu = nn.ReLU()
def forward(self, x):
return self.relu(self.gn(self.conv(x)))
block = ConvBlock(64, 128)
x = torch.randn(2, 64, 32, 32)
output = block(x)
print("输入:", x.shape, "-> 输出:", output.shape)
import torch.nn as nn
# ResNet 结构使用 GroupNorm
class ConvBlock(nn.Module):
def __init__(self, in_ch, out_ch, num_groups=8):
super(ConvBlock, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
self.gn = nn.GroupNorm(num_groups, out_ch)
self.relu = nn.ReLU()
def forward(self, x):
return self.relu(self.gn(self.conv(x)))
block = ConvBlock(64, 128)
x = torch.randn(2, 64, 32, 32)
output = block(x)
print("输入:", x.shape, "-> 输出:", output.shape)
示例 4: 不同 num_groups 的影响
实例
import torch
import torch.nn as nn
# 32通道,用不同的组数
for num_groups in [1, 2, 4, 8, 32]:
gn = nn.GroupNorm(num_groups, 32)
x = torch.randn(4, 32, 8, 8)
out = gn(x)
print(f"groups={num_groups}, 输出均值: {out.mean().item():.4f}")
import torch.nn as nn
# 32通道,用不同的组数
for num_groups in [1, 2, 4, 8, 32]:
gn = nn.GroupNorm(num_groups, 32)
x = torch.randn(4, 32, 8, 8)
out = gn(x)
print(f"groups={num_groups}, 输出均值: {out.mean().item():.4f}")
使用场景
- 小 batch: batch=1 时仍有效
- 视频/3D: BatchNorm 不稳定
- MobileNet: GroupNorm 替代 BatchNorm
注意:num_channels 必须能被 num_groups 整除。

PyTorch torch.nn 参考手册