PyTorch torch.chain_matmul 函数
torch.chain_matmul 是 PyTorch 中用于计算多个矩阵的链式乘法的函数。它通过最优矩阵乘法顺序来最小化计算成本。
函数定义
torch.chain_matmul(*matrices, out=None)
参数:
matrices(Tensor): 输入的矩阵序列。out(Tensor, 可选): 输出张量。
返回值:
torch.Tensor: 返回所有矩阵相乘的结果。
使用示例
实例
import torch
# 创建多个矩阵
A = torch.randn(10, 20)
B = torch.randn(20, 30)
C = torch.randn(30, 40)
D = torch.randn(40, 50)
# 链式矩阵乘法
result = torch.chain_matmul(A, B, C, D)
print("矩阵 A 形状:", A.shape)
print("矩阵 B 形状:", B.shape)
print("矩阵 C 形状:", C.shape)
print("矩阵 D 形状:", D.shape)
print("结果形状:", result.shape)
# 创建多个矩阵
A = torch.randn(10, 20)
B = torch.randn(20, 30)
C = torch.randn(30, 40)
D = torch.randn(40, 50)
# 链式矩阵乘法
result = torch.chain_matmul(A, B, C, D)
print("矩阵 A 形状:", A.shape)
print("矩阵 B 形状:", B.shape)
print("矩阵 C 形状:", C.shape)
print("矩阵 D 形状:", D.shape)
print("结果形状:", result.shape)
输出结果为:
矩阵 A 形状: torch.Size([10, 20]) 矩阵 B 形状: torch.Size([20, 30]) 矩阵 C 形状: torch.Size([30, 40]) 矩阵 D 形状: torch.Size([40, 50]) 结果形状: torch.Size([10, 50])
实例 - 使用列表
import torch
# 矩阵列表
matrices = [torch.randn(10, 20),
torch.randn(20, 30),
torch.randn(30, 40)]
# 使用列表作为参数
result = torch.chain_matmul(*matrices)
print("结果形状:", result.shape)
# 矩阵列表
matrices = [torch.randn(10, 20),
torch.randn(20, 30),
torch.randn(30, 40)]
# 使用列表作为参数
result = torch.chain_matmul(*matrices)
print("结果形状:", result.shape)

Pytorch torch 参考手册