现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.TransformerEncoderLayer 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


nn.TransformerEncoderLayer 是 Transformer 编码器的单层结构。

它包含自注意力和前馈网络,是构成完整编码器的基本单元。

函数定义

torch.nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', batch_first=True)

参数

  • d_model: 模型维度
  • nhead: 注意力头数
  • dim_feedforward: FFN 隐藏层维度

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

# 编码器层
layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)

# 输入
x = torch.randn(32, 100, 512)
output = layer(x)

print("输入:", x.shape, "-> 输出:", output.shape)

示例 2: 自定义激活函数

实例

import torch
import torch.nn as nn

# 使用 ReLU 激活
layer_relu = nn.TransformerEncoderLayer(d_model=256, nhead=4, activation='relu', batch_first=True)

x = torch.randn(8, 50, 256)
out = layer_relu(x)

print("ReLU 激活输出:", out.shape)

示例 3: 多层堆叠

实例

import torch
import torch.nn as nn

# 堆叠多个层
layer = nn.TransformerEncoderLayer(256, 4, batch_first=True)
encoder = nn.Sequential(
    layer,
    nn.TransformerEncoderLayer(256, 4, batch_first=True),
    nn.TransformerEncoderLayer(256, 4, batch_first=True)
)

x = torch.randn(4, 30, 256)
out = encoder(x)

print("3层编码器输出:", out.shape)

使用场景

  • 构建编码器
  • BERT 等模型
  • 文本理解

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册