现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.matmul 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.matmul 是 PyTorch 中用于执行矩阵乘法的函数。它支持不同维度的输入,能够处理一维、二维以及更高维度的张量乘法。

这是深度学习中最常用的运算之一,神经网络的前向传播本质上就是一系列矩阵乘法。

函数定义

torch.matmul(input, other, out=None)

参数:

  • input (Tensor): 第一个输入张量。
  • other (Tensor): 第二个输入张量。
  • out (Tensor, 可选): 输出张量。

返回值:

  • torch.Tensor: 返回矩阵乘法的结果。

使用示例

示例 1: 二维矩阵乘法

实例

import torch

# 创建两个二维矩阵
a = torch.randn(3, 4)  # 3x4 矩阵
b = torch.randn(4, 5)  # 4x5 矩阵

# 矩阵乘法
c = torch.matmul(a, b)

print("a 的形状:", a.shape)
print("b 的形状:", b.shape)
print("c 的形状:", c.shape)

输出结果为:

a 的形状: torch.Size([3, 4])
b 的形状: torch.Size([4, 5])
c 的形状: torch.Size([3, 5])

示例 2: 向量与矩阵相乘

实例

import torch

# 创建向量和矩阵
vector = torch.randn(4)      # 4 维向量
matrix = torch.randn(4, 5)  # 4x5 矩阵

# 向量与矩阵相乘
result = torch.matmul(vector, matrix)

print("向量的形状:", vector.shape)
print("矩阵的形状:", matrix.shape)
print("结果的形状:", result.shape)
print(result)

输出结果为:

向量的形状: torch.Size([4])
矩阵的形状: torch.Size([4, 5])
结果的形状: torch.Size([5])
tensor([-0.7837,  0.3684, -0.6542, -0.4594,  1.5328])
</p>

<h3>示例 3: 批量矩阵乘法</h3>

<div class="example">
<h2 class="example">实例</h2>
<div class="example_code">
import torch

# 创建批量矩阵
batch_a = torch.randn(10, 3, 4)  # 10 个 3x4 矩阵
batch_b = torch.randn(10, 4, 5)  # 10 个 4x5 矩阵

# 批量矩阵乘法
batch_c = torch.matmul(batch_a, batch_b)

print("批量 a 的形状:", batch_a.shape)
print("批量 b 的形状:", batch_b.shape)
print("批量结果 c 的形状:", batch_c.shape)
</div> </div> <p>输出结果为:</p> <pre> 批量 a 的形状: torch.Size([10, 3, 4]) 批量 b 的形状: torch.Size([10, 4, 5]) 批量结果 c 的形状: torch.Size([10, 3, 5])

批量矩阵乘法是深度学习中的常见操作,例如在 Transformer 的注意力机制中。

示例 4: 神经网络中的矩阵乘法

实例

import torch

# 模拟神经网络层:输入 xW + b
x = torch.randn(32, 128)   # 批量大小 32,特征维度 128
W = torch.randn(128, 256)  # 权重矩阵
b = torch.randn(256)       # 偏置向量

# 线性变换:y = x @ W^T + b (在 PyTorch 中通常 W 是转置的)
# 这里演示 x @ W
y = torch.matmul(x, W) + b

print("输入 x 形状:", x.shape)
print("权重 W 形状:", W.shape)
print("输出 y 形状:", y.shape)

输出结果为:

输入 x 形状: torch.Size([32, 128])
权重 W 形状: torch.Size([128, 256])
输出 y 形状: torch.Size([32, 256])

这个示例模拟了神经网络中全连接层的前向传播过程。


注意事项

  • torch.matmul 支持广播,但两个输入的最后两个维度必须满足矩阵乘法的维度要求。
  • 对于二维矩阵乘法,也可以使用 torch.mm(),但 torch.matmul 更通用。
  • 注意区分 torch.matmul(矩阵乘法)和 torch.mul(逐元素乘法)。

Pytorch torch 参考手册 Pytorch torch 参考手册