现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.Parameter 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.Parameter 是 PyTorch 中的可学习参数张量。

它是一个包装器,将普通张量转换为可学习的参数,会自动添加到模块参数列表中。

函数定义

torch.nn.Parameter(data=None, requires_grad=True)

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        # 创建可学习参数
        self.weight = nn.Parameter(torch.randn(10, 5))
        self.bias = nn.Parameter(torch.zeros(5))

    def forward(self, x):
        return x @ self.weight.t() + self.bias

model = MyModule()
print("参数:", list(model.named_parameters()))
print("权重形状:", model.weight.shape)

示例 2: 替代 register_parameter

实例

import torch
import torch.nn as nn

# 方式1:使用 nn.Parameter(推荐)
class Net1(nn.Module):
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.ones(5))

# 方式2:使用 register_parameter(等效)
class Net2(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter('param', nn.Parameter(torch.ones(5)))

print("方式1 参数:", list(Net1().named_parameters()))
print("方式2 参数:", list(Net2().named_parameters()))

示例 3: 不可学习参数

实例

import torch
import torch.nn as nn

# requires_grad=False 创建不可学习参数
param_no_grad = nn.Parameter(torch.ones(5), requires_grad=False)

print("可训练:", param_no_grad.requires_grad)
print("仍然是 Parameter:", type(param_no_grad))

使用场景

  • 自定义层: 实现可学习的权重
  • 特殊模块: 非标准的参数
  • 直接访问: 操作方便

提示:nn.Parameter 是 tensor 的子类,会自动添加到 parameters() 中。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册