现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.InstanceNorm2d 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.InstanceNorm2d 是 PyTorch 中的实例归一化模块。

它对每个样本的每个通道独立归一化,常用于风格迁移。

函数定义

torch.nn.InstanceNorm2d(num_features, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)

参数

  • num_features: 通道数
  • affine: 是否使用可学习参数

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

inorm = nn.InstanceNorm2d(64)

# 输入
x = torch.randn(4, 64, 16, 16)
output = inorm(x)

print("输入:", x.shape, "-> 输出:", output.shape)

示例 2: 风格迁移

实例

import torch
import torch.nn as nn

# 风格迁移网络常用 InstanceNorm
class StyleNet(nn.Module):
    def __init__(self):
        super(StyleNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.in1 = nn.InstanceNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.in2 = nn.InstanceNorm2d(64)

    def forward(self, x):
        x = self.conv1(x)
        x = self.in1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        x = self.in2(x)
        return x

net = StyleNet()
x = torch.randn(1, 3, 256, 256)
output = net(x)

print("输入:", x.shape, "-> 输出:", output.shape)

示例 3: 对比 BatchNorm

实例

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(32)
inorm = nn.InstanceNorm2d(32)

x = torch.randn(4, 32, 8, 8)

print("BatchNorm mean:", bn(x).mean(dim=(0, 2, 3))[:5].tolist())
print("InstanceNorm mean:", inorm(x).mean(dim=(0, 2, 3))[:5].tolist())

使用场景

  • 风格迁移: NIN、AdaIN
  • 纹理合成
  • 小 batch: 适合 batch=1

注意:默认 affine=False,不含可学习参数。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册