PyTorch torch.einsum 函数
torch.einsum 是 PyTorch 中用于爱因斯坦求和约定的函数,可以简洁地表示各种张量运算。
函数定义
torch.einsum(equation, *operands)
使用示例
实例
import torch
# 矩阵转置
A = torch.randn(2, 3)
AT = torch.einsum('ij->ji', A)
print("转置形状:", AT.shape)
# 矩阵乘法
B = torch.randn(3, 4)
C = torch.einsum('ij,jk->ik', A, B)
print("乘法形状:", C.shape)
# 点积
a = torch.randn(3)
b = torch.randn(3)
dot = torch.einsum('i,i->', a, b)
print("点积:", dot.item())
# 矩阵转置
A = torch.randn(2, 3)
AT = torch.einsum('ij->ji', A)
print("转置形状:", AT.shape)
# 矩阵乘法
B = torch.randn(3, 4)
C = torch.einsum('ij,jk->ik', A, B)
print("乘法形状:", C.shape)
# 点积
a = torch.randn(3)
b = torch.randn(3)
dot = torch.einsum('i,i->', a, b)
print("点积:", dot.item())

Pytorch torch 参考手册