PyTorch torch.nn.Parameter 函数
torch.nn.Parameter 是 PyTorch 中的可学习参数张量。
它是一个包装器,将普通张量转换为可学习的参数,会自动添加到模块参数列表中。
函数定义
torch.nn.Parameter(data=None, requires_grad=True)
使用示例
示例 1: 基本用法
实例
import torch
import torch.nn as nn
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
# 创建可学习参数
self.weight = nn.Parameter(torch.randn(10, 5))
self.bias = nn.Parameter(torch.zeros(5))
def forward(self, x):
return x @ self.weight.t() + self.bias
model = MyModule()
print("参数:", list(model.named_parameters()))
print("权重形状:", model.weight.shape)
import torch.nn as nn
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
# 创建可学习参数
self.weight = nn.Parameter(torch.randn(10, 5))
self.bias = nn.Parameter(torch.zeros(5))
def forward(self, x):
return x @ self.weight.t() + self.bias
model = MyModule()
print("参数:", list(model.named_parameters()))
print("权重形状:", model.weight.shape)
示例 2: 替代 register_parameter
实例
import torch
import torch.nn as nn
# 方式1:使用 nn.Parameter(推荐)
class Net1(nn.Module):
def __init__(self):
super().__init__()
self.param = nn.Parameter(torch.ones(5))
# 方式2:使用 register_parameter(等效)
class Net2(nn.Module):
def __init__(self):
super().__init__()
self.register_parameter('param', nn.Parameter(torch.ones(5)))
print("方式1 参数:", list(Net1().named_parameters()))
print("方式2 参数:", list(Net2().named_parameters()))
import torch.nn as nn
# 方式1:使用 nn.Parameter(推荐)
class Net1(nn.Module):
def __init__(self):
super().__init__()
self.param = nn.Parameter(torch.ones(5))
# 方式2:使用 register_parameter(等效)
class Net2(nn.Module):
def __init__(self):
super().__init__()
self.register_parameter('param', nn.Parameter(torch.ones(5)))
print("方式1 参数:", list(Net1().named_parameters()))
print("方式2 参数:", list(Net2().named_parameters()))
示例 3: 不可学习参数
实例
import torch
import torch.nn as nn
# requires_grad=False 创建不可学习参数
param_no_grad = nn.Parameter(torch.ones(5), requires_grad=False)
print("可训练:", param_no_grad.requires_grad)
print("仍然是 Parameter:", type(param_no_grad))
import torch.nn as nn
# requires_grad=False 创建不可学习参数
param_no_grad = nn.Parameter(torch.ones(5), requires_grad=False)
print("可训练:", param_no_grad.requires_grad)
print("仍然是 Parameter:", type(param_no_grad))
使用场景
- 自定义层: 实现可学习的权重
- 特殊模块: 非标准的参数
- 直接访问: 操作方便
提示:nn.Parameter 是 tensor 的子类,会自动添加到 parameters() 中。

PyTorch torch.nn 参考手册