现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.baddbmm 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.baddbmm 是 PyTorch 中用于执行批量矩阵乘法并添加到输入批量矩阵的函数。它对 batch1 和 batch2 中的每一对矩阵进行矩阵乘法,然后与 input 中对应的矩阵相加。

函数定义

torch.baddbmm(input, batch1, batch2, *, beta=1.0, alpha=1.0, out=None)

参数:

  • input (Tensor): 输入批量矩阵,形状为 (b, n, p)。
  • batch1 (Tensor): 第一个批量矩阵,形状为 (b, n, m)。
  • batch2 (Tensor): 第二个批量矩阵,形状为 (b, m, p)。
  • beta (float, 可选): 乘以 input 的系数,默认为 1.0。
  • alpha (float, 可选): 乘以 batch1 @ batch2 结果的系数,默认为 1.0。
  • out (Tensor, 可选): 输出张量。

返回值:

  • torch.Tensor: 返回批量矩阵乘法结果与输入批量矩阵之和,形状为 (b, n, p)。

使用示例

实例

import torch

# 创建输入批量矩阵和批量矩阵
input = torch.randn(10, 3, 3)
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 3)

# 执行 baddbmm
result = torch.baddbmm(input, batch1, batch2)

print("输入批量矩阵形状:", input.shape)
print("批量矩阵1形状:", batch1.shape)
print("批量矩阵2形状:", batch2.shape)
print("结果形状:", result.shape)
print(result[0])  # 打印第一个结果

输出结果为:

输入批量矩阵形状: torch.Size([10, 3, 3])
批量矩阵1形状: torch.Size([10, 3, 4])
批量矩阵2形状: torch.Size([10, 4, 3])
结果形状: torch.Size([10, 3, 3])
tensor([[-1.0917,  0.0174, -0.6599],
        [ 1.0403, -0.3905, -0.6314],
        [ 0.0988, -0.3047,  0.4841]])

Pytorch torch 参考手册 Pytorch torch 参考手册