现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.chain_matmul 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.chain_matmul 是 PyTorch 中用于计算多个矩阵的链式乘法的函数。它通过最优矩阵乘法顺序来最小化计算成本。

函数定义

torch.chain_matmul(*matrices, out=None)

参数:

  • matrices (Tensor): 输入的矩阵序列。
  • out (Tensor, 可选): 输出张量。

返回值:

  • torch.Tensor: 返回所有矩阵相乘的结果。

使用示例

实例

import torch

# 创建多个矩阵
A = torch.randn(10, 20)
B = torch.randn(20, 30)
C = torch.randn(30, 40)
D = torch.randn(40, 50)

# 链式矩阵乘法
result = torch.chain_matmul(A, B, C, D)

print("矩阵 A 形状:", A.shape)
print("矩阵 B 形状:", B.shape)
print("矩阵 C 形状:", C.shape)
print("矩阵 D 形状:", D.shape)
print("结果形状:", result.shape)

输出结果为:

矩阵 A 形状: torch.Size([10, 20])
矩阵 B 形状: torch.Size([20, 30])
矩阵 C 形状: torch.Size([30, 40])
矩阵 D 形状: torch.Size([40, 50])
结果形状: torch.Size([10, 50])

实例 - 使用列表

import torch

# 矩阵列表
matrices = [torch.randn(10, 20),
            torch.randn(20, 30),
            torch.randn(30, 40)]

# 使用列表作为参数
result = torch.chain_matmul(*matrices)
print("结果形状:", result.shape)

Pytorch torch 参考手册 Pytorch torch 参考手册