PyTorch torch.addbmm 函数
torch.addbmm 是 PyTorch 中用于将批量矩阵乘法的结果添加到输入矩阵的函数。它对 batch1 和 batch2 中的每一对矩阵进行矩阵乘法,然后将所有结果求和后添加到 input 中。
函数定义
torch.addbmm(input, batch1, batch2, *, beta=1.0, alpha=1.0, out=None)
参数:
input(Tensor): 输入矩阵,被添加到结果中。batch1(Tensor): 第一个批量矩阵,形状为 (b, n, m)。batch2(Tensor): 第二个批量矩阵,形状为 (b, m, p)。beta(float, 可选): 乘以 input 的系数,默认为 1.0。alpha(float, 可选): 乘以 batch1 @ batch2 结果的系数,默认为 1.0。out(Tensor, 可选): 输出张量。
返回值:
torch.Tensor: 返回批量矩阵乘法结果与输入矩阵之和。
使用示例
实例
import torch
# 创建输入矩阵和批量矩阵
input = torch.randn(3, 3)
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 3)
# 执行 addbmm
result = torch.addbmm(input, batch1, batch2)
print("输入矩阵形状:", input.shape)
print("批量矩阵1形状:", batch1.shape)
print("批量矩阵2形状:", batch2.shape)
print("结果形状:", result.shape)
print(result)
# 创建输入矩阵和批量矩阵
input = torch.randn(3, 3)
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 3)
# 执行 addbmm
result = torch.addbmm(input, batch1, batch2)
print("输入矩阵形状:", input.shape)
print("批量矩阵1形状:", batch1.shape)
print("批量矩阵2形状:", batch2.shape)
print("结果形状:", result.shape)
print(result)
输出结果为:
输入矩阵形状: torch.Size([3, 3])
批量矩阵1形状: torch.Size([10, 3, 4])
批量矩阵2形状: torch.Size([10, 4, 3])
结果形状: torch.Size([3, 3])
tensor([[-0.2875, -0.6518, 0.1207],
[-0.2746, 0.5563, -0.2281],
[ 0.3315, 0.1649, 0.2145]])

Pytorch torch 参考手册