现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.addmm 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.addmm 是 PyTorch 中用于将矩阵乘法的结果添加到输入矩阵的函数。它执行 mat1 @ mat2 的矩阵乘法,然后将结果与 input 相加。

函数定义

torch.addmm(input, mat1, mat2, *, beta=1.0, alpha=1.0, out=None)

参数:

  • input (Tensor): 输入矩阵,被添加到结果中。
  • mat1 (Tensor): 第一个矩阵,形状为 (n, m)。
  • mat2 (Tensor): 第二个矩阵,形状为 (m, p)。
  • beta (float, 可选): 乘以 input 的系数,默认为 1.0。
  • alpha (float, 可选): 乘以 mat1 @ mat2 结果的系数,默认为 1.0。
  • out (Tensor, 可选): 输出张量。

返回值:

  • torch.Tensor: 返回矩阵乘法结果与输入矩阵之和,形状为 (n, p)。

使用示例

实例

import torch

# 创建输入矩阵和两个矩阵
input = torch.randn(3, 3)
mat1 = torch.randn(3, 4)
mat2 = torch.randn(4, 3)

# 执行 addmm
result = torch.addmm(input, mat1, mat2)

print("输入矩阵形状:", input.shape)
print("矩阵1形状:", mat1.shape)
print("矩阵2形状:", mat2.shape)
print("结果形状:", result.shape)
print(result)

输出结果为:

输入矩阵形状: torch.Size([3, 3])
矩阵1形状: torch.Size([3, 4])
矩阵2形状: torch.Size([4, 3])
结果形状: torch.Size([3, 3])
tensor([[ 0.4692, -0.2864, -0.6013],
        [ 1.5525,  0.1233, -0.0182],
        [-0.3956,  0.6267,  0.3580]])

实例 - 使用 alpha 和 beta 参数

import torch

input = torch.randn(3, 3)
mat1 = torch.randn(3, 4)
mat2 = torch.randn(4, 3)

# 使用 alpha 和 beta 参数
result = torch.addmm(input, mat1, mat2, beta=0.5, alpha=2.0)

# 等价于: result = 0.5 * input + 2.0 * (mat1 @ mat2)
print(result)

Pytorch torch 参考手册 Pytorch torch 参考手册