PyTorch torch.nn.InstanceNorm2d 函数
torch.nn.InstanceNorm2d 是 PyTorch 中的实例归一化模块。
它对每个样本的每个通道独立归一化,常用于风格迁移。
函数定义
torch.nn.InstanceNorm2d(num_features, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
参数
num_features: 通道数affine: 是否使用可学习参数
使用示例
示例 1: 基本用法
实例
import torch
import torch.nn as nn
inorm = nn.InstanceNorm2d(64)
# 输入
x = torch.randn(4, 64, 16, 16)
output = inorm(x)
print("输入:", x.shape, "-> 输出:", output.shape)
import torch.nn as nn
inorm = nn.InstanceNorm2d(64)
# 输入
x = torch.randn(4, 64, 16, 16)
output = inorm(x)
print("输入:", x.shape, "-> 输出:", output.shape)
示例 2: 风格迁移
实例
import torch
import torch.nn as nn
# 风格迁移网络常用 InstanceNorm
class StyleNet(nn.Module):
def __init__(self):
super(StyleNet, self).__init__()
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
self.in1 = nn.InstanceNorm2d(32)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.in2 = nn.InstanceNorm2d(64)
def forward(self, x):
x = self.conv1(x)
x = self.in1(x)
x = torch.relu(x)
x = self.conv2(x)
x = self.in2(x)
return x
net = StyleNet()
x = torch.randn(1, 3, 256, 256)
output = net(x)
print("输入:", x.shape, "-> 输出:", output.shape)
import torch.nn as nn
# 风格迁移网络常用 InstanceNorm
class StyleNet(nn.Module):
def __init__(self):
super(StyleNet, self).__init__()
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
self.in1 = nn.InstanceNorm2d(32)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.in2 = nn.InstanceNorm2d(64)
def forward(self, x):
x = self.conv1(x)
x = self.in1(x)
x = torch.relu(x)
x = self.conv2(x)
x = self.in2(x)
return x
net = StyleNet()
x = torch.randn(1, 3, 256, 256)
output = net(x)
print("输入:", x.shape, "-> 输出:", output.shape)
示例 3: 对比 BatchNorm
实例
import torch
import torch.nn as nn
bn = nn.BatchNorm2d(32)
inorm = nn.InstanceNorm2d(32)
x = torch.randn(4, 32, 8, 8)
print("BatchNorm mean:", bn(x).mean(dim=(0, 2, 3))[:5].tolist())
print("InstanceNorm mean:", inorm(x).mean(dim=(0, 2, 3))[:5].tolist())
import torch.nn as nn
bn = nn.BatchNorm2d(32)
inorm = nn.InstanceNorm2d(32)
x = torch.randn(4, 32, 8, 8)
print("BatchNorm mean:", bn(x).mean(dim=(0, 2, 3))[:5].tolist())
print("InstanceNorm mean:", inorm(x).mean(dim=(0, 2, 3))[:5].tolist())
使用场景
- 风格迁移: NIN、AdaIN
- 纹理合成
- 小 batch: 适合 batch=1
注意:默认 affine=False,不含可学习参数。

PyTorch torch.nn 参考手册