现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.ModuleList 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.ModuleList 是 PyTorch 中用于存储模块列表的容器。

它类似于 Python 列表,但会自动注册所有子模块。

函数定义

torch.nn.ModuleList(modules=None)

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layers = nn.ModuleList([
            nn.Linear(10, 20),
            nn.Linear(20, 20),
            nn.Linear(20, 5)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = Net()
x = torch.randn(4, 10)
output = model(x)

print("输入:", x.shape)
print("输出:", output.shape)
print("参数数量:", sum(p.numel() for p in model.parameters()))

示例 2: 动态构建层

实例

import torch
import torch.nn as nn

class DynamicNet(nn.Module):
    def __init__(self, num_layers, dim):
        super(DynamicNet, self).__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_dim = dim if i == 0 else dim
            self.layers.append(nn.Linear(in_dim, dim))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:
                x = torch.relu(x)
        return x

model = DynamicNet(num_layers=5, dim=64)
print("层数:", len(model.layers))
print("总参数:", sum(p.numel() for p in model.parameters()))

示例 3: 与 Sequential 对比

实例

import torch
import torch.nn as nn

# ModuleList - 灵活
mlist = nn.ModuleList([nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)])

# Sequential - 固定顺序
seq = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))

print("ModuleList 索引访问:", mlist[0])
print("Sequential 索引访问:", seq[0])
print("nModuleList 可遍历但无 forward")
print("Sequential 直接调用")

使用场景

  • 多层网络: 动态构建
  • 循环结构: 共享层
  • 复杂模型: 需要自定义控制流

注意:必须使用 nn.ModuleList 而非 Python 列表。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册