PyTorch torch.nn.MultiheadAttention
torch.nn.MultiheadAttention is PyTorch's multi-head attention module.
It is a core component of the Transformer architecture, allowing the model to jointly attend to information from different representation subspaces at different positions.
Function Definition
torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, kdim=None, vdim=None, batch_first=False)
Parameters:
- embed_dim (int): dimension of the input embeddings.
- num_heads (int): number of attention heads; embed_dim must be divisible by num_heads.
- dropout (float): dropout probability applied to the attention weights. Default: 0.0.
- bias (bool): whether the input and output projection layers use a bias. Default: True.
- add_bias_kv (bool): whether to add a bias to the key and value sequences. Default: False.
- kdim (int): dimension of the key vectors. Default: None (same as embed_dim). See the sketch after this list.
- vdim (int): dimension of the value vectors. Default: None (same as embed_dim).
- batch_first (bool): if True, inputs and outputs are shaped (batch, seq, embed_dim). Default: False.
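The kdim and vdim parameters are useful when keys and values live in a different feature space than the queries. Below is a minimal sketch (the tensor sizes are made up for illustration) showing that the output is always projected back to embed_dim:
import torch
import torch.nn as nn
# Queries are 512-dim, keys/values come from a 256-dim feature space
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, kdim=256, vdim=256, batch_first=True)
query = torch.randn(4, 100, 512)
key = torch.randn(4, 80, 256)
value = torch.randn(4, 80, 256)
output, weights = mha(query, key, value)
print(output.shape)   # torch.Size([4, 100, 512]) - projected back to embed_dim
print(weights.shape)  # torch.Size([4, 100, 80])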
Usage Examples
Example 1: Basic Usage
import torch
import torch.nn as nn
# Create multi-head attention: 512 dims, 8 heads, batch-first inputs
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
# Inputs: batch=4, sequence length=100, embedding dim=512
query = torch.randn(4, 100, 512)
key = torch.randn(4, 100, 512)
value = torch.randn(4, 100, 512)
# Forward pass
output, attn_weights = mha(query, key, value)
print("Query shape:", query.shape)
print("Output shape:", output.shape)
print("Attention weights shape:", attn_weights.shape)
Example 2: Self-Attention
import torch
import torch.nn as nn
mha = nn.MultiheadAttention(embed_dim=256, num_heads=4, batch_first=True)
# The same tensor is used as Q, K, and V (self-attention)
x = torch.randn(2, 50, 256)
# Self-attention: q = k = v = x
output, weights = mha(x, x, x)
print("Input shape:", x.shape)
print("Output shape:", output.shape)
print("Attention weights shape:", weights.shape)
Example 3: Attention with a Mask
import torch
import torch.nn as nn
mha = nn.MultiheadAttention(embed_dim=128, num_heads=4, batch_first=True)
# Input
x = torch.randn(1, 20, 128)
# Upper-triangular causal mask (used in decoders): True positions are blocked
mask = torch.triu(torch.ones(20, 20), diagonal=1).bool()
output, _ = mha(x, x, x, attn_mask=mask)
print("Input shape:", x.shape)
print("Output shape:", output.shape)
print("Mask shape:", mask.shape)
Example 4: A Complete Transformer Encoder Layer
import torch
import torch.nn as nn
class TransformerLayer(nn.Module):
    def __init__(self, d_model, nhead):
        super(TransformerLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model)
        )
    def forward(self, x):
        # Self-attention + residual connection
        attn_out, _ = self.self_attn(x, x, x)
        x = self.norm1(x + attn_out)
        # Feed-forward network + residual connection
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        return x
# Test
layer = TransformerLayer(d_model=512, nhead=8)
x = torch.randn(4, 100, 512)
output = layer(x)
print("Input shape:", x.shape)
print("Output shape:", output.shape)
Example 5: Inspecting Attention Weights
import torch
import torch.nn as nn
import numpy as np
mha = nn.MultiheadAttention(embed_dim=64, num_heads=2, batch_first=True)
# A short sequence
x = torch.randn(1, 5, 64)
# average_attn_weights=False returns per-head weights of shape (batch, num_heads, L, S)
_, attn = mha(x, x, x, average_attn_weights=False)
attn = attn.squeeze(0)  # drop the batch dimension -> (num_heads, 5, 5)
print("Attention weights of the first head (first 3 positions):")
print(attn[0, :3, :3].tolist())
print("\nVisualization - attention from position 0 to all positions:")
print(np.array2string(attn[0, 0].detach().numpy(), precision=2))
FAQ
Q1: How should num_heads be chosen?
embed_dim must be divisible by num_heads; common choices are 8, 12, or 16, as in the sketch below.
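A quick check (the numbers are only illustrative): the per-head dimension is embed_dim // num_heads, and an indivisible combination fails when the module is constructed.
import torch.nn as nn
# 512-dim embeddings split across 8 heads -> 64 dims per head
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8)
print(mha.head_dim)  # 64
# An indivisible combination raises an error at construction time:
# nn.MultiheadAttention(embed_dim=512, num_heads=7)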
Q2: Why are three matrices Q, K, and V needed?
Separate projections for queries, keys, and values let the model learn different representations for matching and for content, which increases its expressive power. They also allow the query to come from a different sequence than the key/value (cross-attention), as sketched below.
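A minimal cross-attention sketch (tensor names and sizes are illustrative only):
import torch
import torch.nn as nn
mha = nn.MultiheadAttention(embed_dim=256, num_heads=4, batch_first=True)
decoder_states = torch.randn(2, 10, 256)  # queries: 10 target positions
encoder_states = torch.randn(2, 50, 256)  # keys/values: 50 source positions
output, weights = mha(decoder_states, encoder_states, encoder_states)
print(output.shape)   # torch.Size([2, 10, 256]) - follows the query length
print(weights.shape)  # torch.Size([2, 10, 50])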
Q3: What is key_padding_mask?
It masks out padding positions so that attention is never computed over padding tokens; see the sketch below.
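A minimal sketch (the padding pattern is made up for illustration); key_padding_mask has shape (batch, source_len) and True marks positions to ignore:
import torch
import torch.nn as nn
mha = nn.MultiheadAttention(embed_dim=128, num_heads=4, batch_first=True)
x = torch.randn(2, 6, 128)
# True marks padding positions that should receive no attention
key_padding_mask = torch.tensor([
    [False, False, False, False, True, True],   # last 2 tokens are padding
    [False, False, False, False, False, True],  # last token is padding
])
output, weights = mha(x, x, x, key_padding_mask=key_padding_mask)
print(weights[0, 0])  # attention from position 0; padded positions get weight 0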
Use Cases
- Transformer: encoder and decoder
- Self-attention models: BERT, GPT
- Sequence modeling: as a replacement for RNNs
Tip: with batch_first=True the input shape is (batch, seq, embed_dim); with the default batch_first=False it is (seq, batch, embed_dim).
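A small sketch contrasting the two layouts (shapes chosen arbitrarily for illustration):
import torch
import torch.nn as nn
# Default layout (batch_first=False): (seq, batch, embed_dim)
mha_seq_first = nn.MultiheadAttention(embed_dim=64, num_heads=4)
x = torch.randn(10, 2, 64)  # 10 time steps, batch of 2
out, _ = mha_seq_first(x, x, x)
print(out.shape)  # torch.Size([10, 2, 64])
# batch_first=True: (batch, seq, embed_dim)
mha_batch_first = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
y = torch.randn(2, 10, 64)  # batch of 2, 10 time steps
out, _ = mha_batch_first(y, y, y)
print(out.shape)  # torch.Size([2, 10, 64])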
