PyTorch torch.nn.Module 函数
torch.nn.Module 是 PyTorch 中所有神经网络模块的基类。
所有自定义的网络模型都应该继承自这个类,它提供了参数管理、设备迁移、模型保存等功能。
类定义
torch.nn.Module
主要属性
parameters(): 返回模型的所有可学习参数
named_parameters(): 返回参数名称和值的迭代器
children(): 返回模型的子模块
named_children(): 返回子模块名称和模块的迭代器
modules(): 返回所有模块
state_dict(): 返回包含所有参数的字典
使用示例
示例 1: 创建自定义模块
继承 nn.Module 创建自定义网络:
实例
import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    """A minimal feed-forward network: Linear -> ReLU -> Linear."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SimpleNet, self).__init__()
        # Define the network layers; assigning nn.Module attributes
        # registers their parameters automatically.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Forward pass: fc1 -> ReLU -> fc2
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


# Instantiate the model
model = SimpleNet(input_dim=784, hidden_dim=256, output_dim=10)

# Smoke test with a random batch of 32 samples
x = torch.randn(32, 784)
output = model(x)
print("模型结构:")
print(model)
print("\n输入形状:", x.shape)  # fixed: "\n" was garbled to "n" in the original
print("输出形状:", output.shape)
import torch  # fixed: missing in the original snippet but required for torch.randn
import torch.nn as nn


class SimpleNet(nn.Module):
    """A minimal feed-forward network: Linear -> ReLU -> Linear."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SimpleNet, self).__init__()
        # Define the network layers; module attributes register parameters.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Forward pass: fc1 -> ReLU -> fc2
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


# Instantiate the model
model = SimpleNet(input_dim=784, hidden_dim=256, output_dim=10)

# Smoke test with a random batch of 32 samples
x = torch.randn(32, 784)
output = model(x)
print("模型结构:")
print(model)
print("\n输入形状:", x.shape)  # fixed: "\n" was garbled to "n" in the original
print("输出形状:", output.shape)
示例 2: 管理参数
访问和管理模型参数:
实例
import torch
import torch.nn as nn


class Net(nn.Module):
    """Small conv net used to demonstrate parameter inspection."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.bn1 = nn.BatchNorm2d(16)
        # 16 * 14 * 14 assumes a 16x14x14 feature map going into the
        # classifier (i.e. a fixed input size chosen by the tutorial).
        self.fc = nn.Linear(16 * 14 * 14, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = x.view(x.size(0), -1)  # flatten everything except the batch dim
        x = self.fc(x)
        return x


model = Net()

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("总参数量:", total_params)
print("可训练参数量:", trainable_params)

# Access specific parameters by attribute
print("\nconv1 权重形状:", model.conv1.weight.shape)  # fixed: "\n" was garbled to "n"
print("fc 偏置形状:", model.fc.bias.shape)
import torch.nn as nn


class Net(nn.Module):
    """Small conv net used to demonstrate parameter inspection."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.bn1 = nn.BatchNorm2d(16)
        # 16 * 14 * 14 assumes a 16x14x14 feature map into the classifier.
        self.fc = nn.Linear(16 * 14 * 14, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = x.view(x.size(0), -1)  # flatten everything except the batch dim
        x = self.fc(x)
        return x


model = Net()

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("总参数量:", total_params)
print("可训练参数量:", trainable_params)

# Access specific parameters by attribute
print("\nconv1 权重形状:", model.conv1.weight.shape)  # fixed: "\n" was garbled to "n"
print("fc 偏置形状:", model.fc.bias.shape)
示例 3: 模型保存与加载
保存和加载模型:
实例
import torch
import torch.nn as nn
import tempfile
import os


class Net(nn.Module):
    """Tiny single-layer model used to demonstrate save/load."""

    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        return self.fc(x)


model = Net()

# Save the whole model object (pickles the class too; fragile across refactors)
with tempfile.NamedTemporaryFile(delete=False, suffix='.pth') as f:
    torch.save(model, f.name)
    path_full = f.name

# Save only the state_dict (recommended)
with tempfile.NamedTemporaryFile(delete=False, suffix='.pth') as f:
    torch.save(model.state_dict(), f.name)
    path_state = f.name

# Load the weights into a fresh instance.
# weights_only=True restricts unpickling to tensors/containers, avoiding
# arbitrary-code execution from untrusted checkpoints (torch >= 1.13).
loaded_model = Net()
loaded_model.load_state_dict(torch.load(path_state, weights_only=True))
loaded_model.eval()

# Verify the loaded model reproduces the original outputs
x = torch.randn(2, 10)
output1 = model(x)
output2 = loaded_model(x)
print("原始输出:", output1[0].tolist())
print("加载后输出:", output2[0].tolist())

# Clean up temp files
os.remove(path_full)
os.remove(path_state)
import torch  # fixed: missing in the original snippet but required for torch.save/load/randn
import torch.nn as nn
import tempfile
import os


class Net(nn.Module):
    """Tiny single-layer model used to demonstrate save/load."""

    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        return self.fc(x)


model = Net()

# Save the whole model object (pickles the class too; fragile across refactors)
with tempfile.NamedTemporaryFile(delete=False, suffix='.pth') as f:
    torch.save(model, f.name)
    path_full = f.name

# Save only the state_dict (recommended)
with tempfile.NamedTemporaryFile(delete=False, suffix='.pth') as f:
    torch.save(model.state_dict(), f.name)
    path_state = f.name

# Load the weights into a fresh instance.
# weights_only=True avoids arbitrary-code pickle loading (torch >= 1.13).
loaded_model = Net()
loaded_model.load_state_dict(torch.load(path_state, weights_only=True))
loaded_model.eval()

# Verify the loaded model reproduces the original outputs
x = torch.randn(2, 10)
output1 = model(x)
output2 = loaded_model(x)
print("原始输出:", output1[0].tolist())
print("加载后输出:", output2[0].tolist())

# Clean up temp files
os.remove(path_full)
os.remove(path_state)
示例 4: 设备迁移
在不同设备间迁移模型:
实例
import torch
import torch.nn as nn


class Net(nn.Module):
    """Tiny single-layer model used to demonstrate device migration."""

    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        return self.fc(x)


model = Net()

# Inspect the device a parameter currently lives on
print("参数设备:", model.fc.weight.device)

# Move to GPU if one is available; .cuda()/.cpu() move parameters in place
# but are conventionally re-assigned for clarity.
if torch.cuda.is_available():
    model = model.cuda()
    print("移至 GPU 后:", model.fc.weight.device)

# Move back to CPU
model = model.cpu()
print("移回 CPU 后:", model.fc.weight.device)
import torch  # fixed: missing in the original snippet but required for torch.cuda.is_available
import torch.nn as nn


class Net(nn.Module):
    """Tiny single-layer model used to demonstrate device migration."""

    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        return self.fc(x)


model = Net()

# Inspect the device a parameter currently lives on
print("参数设备:", model.fc.weight.device)

# Move to GPU if one is available
if torch.cuda.is_available():
    model = model.cuda()
    print("移至 GPU 后:", model.fc.weight.device)

# Move back to CPU
model = model.cpu()
print("移回 CPU 后:", model.fc.weight.device)
示例 5: 使用 apply 初始化
使用 apply 递归初始化:
实例
import torch
import torch.nn as nn


def init_weights(module):
    """Initializer for Module.apply: Xavier for Linear, Kaiming for Conv2d."""
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:  # Linear may be created with bias=False
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight)


class Net(nn.Module):
    """Two-layer MLP used to demonstrate recursive initialization."""

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 5)

    def forward(self, x):
        return self.fc2(self.fc1(x))


model = Net()
# apply() calls init_weights on every submodule recursively
model.apply(init_weights)
print("fc1 权重:", model.fc1.weight[0, :3].tolist())
print("fc2 权重:", model.fc2.weight[0, :3].tolist())
import torch.nn as nn


def init_weights(module):
    """Initializer for Module.apply: Xavier for Linear, Kaiming for Conv2d."""
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:  # Linear may be created with bias=False
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight)


class Net(nn.Module):
    """Two-layer MLP used to demonstrate recursive initialization."""

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 5)

    def forward(self, x):
        return self.fc2(self.fc1(x))


model = Net()
# apply() calls init_weights on every submodule recursively
model.apply(init_weights)
print("fc1 权重:", model.fc1.weight[0, :3].tolist())
print("fc2 权重:", model.fc2.weight[0, :3].tolist())
示例 6: 复杂网络结构
构建带分支的网络:
实例
import torch
import torch.nn as nn


class BranchNet(nn.Module):
    """Network with a shared trunk and two output branches (multi-task style)."""

    def __init__(self):
        super(BranchNet, self).__init__()
        # Shared trunk
        self.shared = nn.Linear(10, 20)
        # Task-specific branches
        self.branch_a = nn.Linear(20, 5)
        self.branch_b = nn.Linear(20, 3)

    def forward(self, x):
        # Both branches consume the same shared features
        feat = self.shared(x)
        out_a = self.branch_a(feat)
        out_b = self.branch_b(feat)
        return out_a, out_b


model = BranchNet()
x = torch.randn(4, 10)
out_a, out_b = model(x)
print("分支 A 输出:", out_a.shape)
print("分支 B 输出:", out_b.shape)
import torch  # fixed: missing in the original snippet but required for torch.randn
import torch.nn as nn


class BranchNet(nn.Module):
    """Network with a shared trunk and two output branches (multi-task style)."""

    def __init__(self):
        super(BranchNet, self).__init__()
        # Shared trunk
        self.shared = nn.Linear(10, 20)
        # Task-specific branches
        self.branch_a = nn.Linear(20, 5)
        self.branch_b = nn.Linear(20, 3)

    def forward(self, x):
        # Both branches consume the same shared features
        feat = self.shared(x)
        out_a = self.branch_a(feat)
        out_b = self.branch_b(feat)
        return out_a, out_b


model = BranchNet()
x = torch.randn(4, 10)
out_a, out_b = model(x)
print("分支 A 输出:", out_a.shape)
print("分支 B 输出:", out_b.shape)
常见问题
Q1: 为什么 super().__init__() 必须调用?
它会调用父类的初始化方法,确保模块正确注册参数。
Q2: 如何查看模型结构?
直接 print(model),或使用 torchinfo / torchsummary 库的 summary 函数。
Q3: parameters() 和 modules() 的区别?
parameters() 返回可学习参数,modules() 返回所有模块。
使用场景
nn.Module 是构建所有自定义神经网络的基础,主要应用场景包括:
- 定义网络结构: 任何自定义神经网络
- 模型管理: 保存、加载、迁移
- 参数管理: 优化器更新、梯度计算
提示:所有自定义 Module 都需要实现 forward 方法。

PyTorch torch.nn 参考手册