PyTorch torch.baddbmm 函数
torch.baddbmm 是 PyTorch 中用于执行批量矩阵乘法并添加到输入批量矩阵的函数。它对 batch1 和 batch2 中的每一对矩阵进行矩阵乘法,然后与 input 中对应的矩阵相加。
函数定义
torch.baddbmm(input, batch1, batch2, *, beta=1.0, alpha=1.0, out=None)
参数:
input(Tensor): 输入批量矩阵,形状为 (b, n, p)。batch1(Tensor): 第一个批量矩阵,形状为 (b, n, m)。batch2(Tensor): 第二个批量矩阵,形状为 (b, m, p)。beta(float, 可选): 乘以 input 的系数,默认为 1.0。alpha(float, 可选): 乘以 batch1 @ batch2 结果的系数,默认为 1.0。out(Tensor, 可选): 输出张量。
返回值:
torch.Tensor: 返回批量矩阵乘法结果与输入批量矩阵之和,形状为 (b, n, p)。
使用示例
实例
import torch
# 创建输入批量矩阵和批量矩阵
input = torch.randn(10, 3, 3)
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 3)
# 执行 baddbmm
result = torch.baddbmm(input, batch1, batch2)
print("输入批量矩阵形状:", input.shape)
print("批量矩阵1形状:", batch1.shape)
print("批量矩阵2形状:", batch2.shape)
print("结果形状:", result.shape)
print(result[0]) # 打印第一个结果
# 创建输入批量矩阵和批量矩阵
input = torch.randn(10, 3, 3)
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 3)
# 执行 baddbmm
result = torch.baddbmm(input, batch1, batch2)
print("输入批量矩阵形状:", input.shape)
print("批量矩阵1形状:", batch1.shape)
print("批量矩阵2形状:", batch2.shape)
print("结果形状:", result.shape)
print(result[0]) # 打印第一个结果
输出结果为:
输入批量矩阵形状: torch.Size([10, 3, 3])
批量矩阵1形状: torch.Size([10, 3, 4])
批量矩阵2形状: torch.Size([10, 4, 3])
结果形状: torch.Size([10, 3, 3])
tensor([[-1.0917, 0.0174, -0.6599],
[ 1.0403, -0.3905, -0.6314],
[ 0.0988, -0.3047, 0.4841]])

Pytorch torch 参考手册