PyTorch torch.nn.ModuleList 函数
torch.nn.ModuleList 是 PyTorch 中用于存储模块列表的容器。
它类似于 Python 列表,但会自动注册所有子模块。
函数定义
torch.nn.ModuleList(modules=None)
使用示例
示例 1: 基本用法
实例
import torch
import torch.nn as nn


class Net(nn.Module):
    """Three-layer MLP whose Linear layers live in an nn.ModuleList.

    nn.ModuleList registers every contained module as a child, so they
    all show up in model.parameters() and move with .to(device).
    """

    def __init__(self):
        super(Net, self).__init__()
        self.layers = nn.ModuleList([
            nn.Linear(10, 20),
            nn.Linear(20, 20),
            nn.Linear(20, 5),
        ])

    def forward(self, x):
        # ModuleList has no forward() of its own; iterate it manually.
        for layer in self.layers:
            x = layer(x)
        return x


model = Net()
x = torch.randn(4, 10)
output = model(x)
print("输入:", x.shape)
print("输出:", output.shape)
print("参数数量:", sum(p.numel() for p in model.parameters()))
# Fix: the original snippet called torch.randn below but imported only
# torch.nn — `import torch` was missing, so the example raised NameError.
import torch
import torch.nn as nn


class Net(nn.Module):
    """Three-layer MLP whose Linear layers are stored in an nn.ModuleList."""

    def __init__(self):
        super(Net, self).__init__()
        self.layers = nn.ModuleList([
            nn.Linear(10, 20),
            nn.Linear(20, 20),
            nn.Linear(20, 5),
        ])

    def forward(self, x):
        # ModuleList provides no forward(); chain the layers manually.
        for layer in self.layers:
            x = layer(x)
        return x


model = Net()
x = torch.randn(4, 10)
output = model(x)
print("输入:", x.shape)
print("输出:", output.shape)
print("参数数量:", sum(p.numel() for p in model.parameters()))
示例 2: 动态构建层
实例
import torch
import torch.nn as nn


class DynamicNet(nn.Module):
    """MLP with a configurable number of equally-sized Linear layers.

    Demonstrates building a network dynamically by appending to an
    nn.ModuleList from constructor arguments.

    Args:
        num_layers: how many Linear layers to stack.
        dim: width of every layer (input and output).
    """

    def __init__(self, num_layers, dim):
        super(DynamicNet, self).__init__()
        self.layers = nn.ModuleList()
        # Every layer maps dim -> dim. (The original's
        # `dim if i == 0 else dim` was a no-op conditional — both
        # branches were identical — so it is simplified away.)
        for _ in range(num_layers):
            self.layers.append(nn.Linear(dim, dim))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # ReLU between layers, but not after the final one.
            if i < len(self.layers) - 1:
                x = torch.relu(x)
        return x


model = DynamicNet(num_layers=5, dim=64)
print("层数:", len(model.layers))
print("总参数:", sum(p.numel() for p in model.parameters()))
# Fix: `import torch` was missing even though forward() calls torch.relu,
# so calling the model raised NameError. Also removes the no-op
# conditional `dim if i == 0 else dim` (both branches were identical).
import torch
import torch.nn as nn


class DynamicNet(nn.Module):
    """MLP with a configurable number of equally-sized Linear layers.

    Args:
        num_layers: how many Linear layers to stack.
        dim: width of every layer (input and output).
    """

    def __init__(self, num_layers, dim):
        super(DynamicNet, self).__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(nn.Linear(dim, dim))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # ReLU between layers, but not after the final one.
            if i < len(self.layers) - 1:
                x = torch.relu(x)
        return x


model = DynamicNet(num_layers=5, dim=64)
print("层数:", len(model.layers))
print("总参数:", sum(p.numel() for p in model.parameters()))
示例 3: 与 Sequential 对比
实例
import torch
import torch.nn as nn

# ModuleList - flexible: just a registered container; you write the control flow.
mlist = nn.ModuleList([nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)])
# Sequential - fixed order: has its own forward() that chains the modules.
seq = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
print("ModuleList 索引访问:", mlist[0])
print("Sequential 索引访问:", seq[0])
# Fix: the original printed "nModuleList ..." — a mangled "\n" escape.
print("\nModuleList 可遍历但无 forward")
print("Sequential 直接调用")
import torch.nn as nn

# ModuleList - flexible: just a registered container; you write the control flow.
mlist = nn.ModuleList([nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)])
# Sequential - fixed order: has its own forward() that chains the modules.
seq = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
print("ModuleList 索引访问:", mlist[0])
print("Sequential 索引访问:", seq[0])
# Fix: the original printed "nModuleList ..." — a mangled "\n" escape.
print("\nModuleList 可遍历但无 forward")
print("Sequential 直接调用")
使用场景
- 多层网络: 动态构建
- 循环结构: 共享层
- 复杂模型: 需要自定义控制流
注意:必须使用 nn.ModuleList 而非 Python 列表。

PyTorch torch.nn 参考手册