现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.MultiheadAttention 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.MultiheadAttention 是 PyTorch 中的多头注意力机制模块。

它是 Transformer 架构的核心组件,允许模型同时关注来自不同位置的不同表示子空间的信息。

函数定义

torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, kdim=None, vdim=None, batch_first=True)

参数说明:

  • embed_dim (int): 输入嵌入维度。
  • num_heads (int): 注意力头的数量。
  • dropout (float): dropout 概率。默认为 0。
  • kdim (int): 键向量的维度。默认为 None(与 embed_dim 相同)。
  • vdim (int): 值向量的维度。默认为 None(与 embed_dim 相同)。
  • batch_first (bool): 如果为 True,输入输出第一维是 batch。默认为 True。

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

# 创建多头注意力:512维,8个头
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8)

# 输入:batch=4,序列长度=100,维度=512
query = torch.randn(4, 100, 512)
key = torch.randn(4, 100, 512)
value = torch.randn(4, 100, 512)

# 前向传播
output, attn_weight = mha(query, key, value)

print("Query 形状:", query.shape)
print("Output 形状:", output.shape)
print("注意力权重形状:", attn_weight.shape)

示例 2: 自注意力

实例

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=256, num_heads=4)

# 相同的输入作为 Q、K、V(自注意力)
x = torch.randn(2, 50, 256)

# self-attention: q=k=v=x
output, weights = mha(x, x, x)

print("输入形状:", x.shape)
print("输出形状:", output.shape)
print("注意力权重形状:", weights.shape)

示例 3: 带 mask 的注意力

实例

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=128, num_heads=4)

# 输入
x = torch.randn(1, 20, 128)

# 创建上三角 mask(用于解码器)
mask = torch.triu(torch.ones(20, 20), diagonal=1).bool()

output, _ = mha(x, x, x, attn_mask=mask)

print("输入形状:", x.shape)
print("Output 形状:", output.shape)
print("Mask 形状:", mask.shape)

示例 4: 完整的 Transformer 编码器层

实例

import torch
import torch.nn as nn

class TransformerLayer(nn.Module):
    def __init__(self, d_model, nhead):
        super(TransformerLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model)
        )

    def forward(self, x):
        # 自注意力 + 残差
        attn_out, _ = self.self_attn(x, x, x)
        x = self.norm1(x + attn_out)

        # FFN + 残差
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)

        return x

# 测试
layer = TransformerLayer(d_model=512, nhead=8)
x = torch.randn(4, 100, 512)
output = layer(x)

print("输入形状:", x.shape)
print("输出形状:", output.shape)

示例 5: 查看注意力权重

实例

import torch
import torch.nn as nn
import numpy as np

mha = nn.MultiheadAttention(embed_dim=64, num_heads=2, batch_first=True)

# 简短视频序列
x = torch.randn(1, 5, 64)

_, attn = mha(x, x, x)
attn = attn.squeeze(0)  # 去掉 batch

print("第一个头的注意力权重 (前3个位置):")
print(attn[0, :3, :3].tolist())

print("n可视化 - 位置0对所有位置的注意力:")
print(np.array2string(attn[0, 0].numpy(), precision=2))

常见问题

Q1: num_heads 如何选择?

embed_dim 必须能被 num_heads 整除。常用值:8、12、16。

Q2: 为什么需要 Q、K、V 三个矩阵?

允许模型学习不同的投影,增强表达能力。

Q3: key_padding_mask 是什么?

用于遮盖 padding 位置,避免注意力计算到 padding。


使用场景

  • Transformer: 编码器和解码器
  • 自注意力模型: BERT、GPT
  • 序列建模: 替代 RNN

提示:batch_first=True 时,输入形状为 (batch, seq, embed_dim)。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册