PyTorch torch.mm 函数
torch.mm 是 PyTorch 中用于执行二维矩阵乘法的函数。
函数定义
torch.mm(input, mat2, out)
使用示例
实例
import torch
A = torch.randn(2, 3)
B = torch.randn(3, 4)
# 矩阵乘法
C = torch.mm(A, B)
print("A 形状:", A.shape)
print("B 形状:", B.shape)
print("C 形状:", C.shape)
A = torch.randn(2, 3)
B = torch.randn(3, 4)
# 矩阵乘法
C = torch.mm(A, B)
print("A 形状:", A.shape)
print("B 形状:", B.shape)
print("C 形状:", C.shape)
输出结果为:
A 形状: torch.Size([2, 3]) B 形状: torch.Size([3, 4]) C 形状: torch.Size([2, 4])

Pytorch torch 参考手册