PyTorch torch.nn.TransformerEncoderLayer
nn.TransformerEncoderLayer is a single layer of a Transformer encoder. It combines a multi-head self-attention sublayer with a position-wise feed-forward network, and is the basic building block from which a full encoder is stacked.
Syntax
torch.nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='relu', layer_norm_eps=1e-5, batch_first=False, norm_first=False)
Note: the defaults are activation='relu' and batch_first=False; pass batch_first=True to use (batch, seq, feature) inputs.
Parameters
- d_model: the number of expected features in the input (model dimension)
- nhead: the number of attention heads (must evenly divide d_model)
- dim_feedforward: hidden dimension of the feed-forward network (default: 2048)
- dropout: dropout probability (default: 0.1)
- activation: activation of the feed-forward network, 'relu' (default) or 'gelu'
- batch_first: if True, inputs are (batch, seq, feature); default: False, i.e. (seq, batch, feature)
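In addition to these constructor arguments, the layer's forward() accepts optional masks; a minimal sketch of the padding-mask usage (the shapes below are chosen purely for illustration):

```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
x = torch.randn(2, 5, 64)  # (batch, seq, d_model)
# In src_key_padding_mask, True marks positions that attention should ignore
pad_mask = torch.tensor([[False, False, False, True, True],
                         [False, False, False, False, False]])
out = layer(x, src_key_padding_mask=pad_mask)
print(out.shape)  # torch.Size([2, 5, 64]) -- shape is preserved
```

The output always has the same shape as the input, masked or not; the mask only changes which positions each token may attend to.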
Usage Examples
Example 1: Basic Usage
import torch
import torch.nn as nn
# a single encoder layer
layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
# input: (batch, seq, d_model)
x = torch.randn(32, 100, 512)
output = layer(x)
print("input:", x.shape, "-> output:", output.shape)
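Besides the basic call above, the norm_first flag switches between the default Post-LN ordering and the Pre-LN ordering often used for training stability; a minimal sketch (the dimensions are illustrative):

```python
import torch
import torch.nn as nn

# norm_first=True applies LayerNorm before the attention and FFN sublayers (Pre-LN)
layer = nn.TransformerEncoderLayer(d_model=128, nhead=8, norm_first=True,
                                   batch_first=True)
x = torch.randn(16, 20, 128)
out = layer(x)
print(out.shape)  # torch.Size([16, 20, 128])
```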
Example 2: Custom Activation Function
import torch
import torch.nn as nn
# use ReLU activation in the feed-forward network
layer_relu = nn.TransformerEncoderLayer(d_model=256, nhead=4, activation='relu', batch_first=True)
x = torch.randn(8, 50, 256)
out = layer_relu(x)
print("ReLU activation output:", out.shape)
Example 3: Stacking Multiple Layers
import torch
import torch.nn as nn
# stack several layers
layer = nn.TransformerEncoderLayer(256, 4, batch_first=True)
encoder = nn.Sequential(
    layer,
    nn.TransformerEncoderLayer(256, 4, batch_first=True),
    nn.TransformerEncoderLayer(256, 4, batch_first=True)
)
x = torch.randn(4, 30, 256)
out = encoder(x)
print("3-layer encoder output:", out.shape)
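nn.Sequential works here only because each layer maps a tensor to a same-shaped tensor; the idiomatic way to stack is nn.TransformerEncoder, which deep-copies the given layer num_layers times (so the layers do not share weights) and also forwards masks through every layer:

```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=256, nhead=4, batch_first=True)
# clones `layer` 3 times; each clone gets independent weights
encoder = nn.TransformerEncoder(layer, num_layers=3)
x = torch.randn(4, 30, 256)
out = encoder(x)
print(out.shape)  # torch.Size([4, 30, 256])
```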
Use Cases
- Building Transformer encoders
- BERT-style models
- Text understanding
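The use cases above can be sketched as a tiny BERT-style text classifier; everything below (the class name, vocabulary size, and mean pooling over positions) is an illustrative assumption, not a fixed recipe:

```python
import torch
import torch.nn as nn

class TinyTextEncoder(nn.Module):
    """Hypothetical minimal text classifier built on TransformerEncoderLayer."""
    def __init__(self, vocab_size=1000, d_model=128, nhead=4,
                 num_layers=2, num_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, tokens):
        h = self.encoder(self.embed(tokens))  # (batch, seq, d_model)
        return self.head(h.mean(dim=1))       # mean-pool over seq, then classify

model = TinyTextEncoder()
tokens = torch.randint(0, 1000, (4, 16))  # batch of 4 sequences of 16 token ids
logits = model(tokens)
print(logits.shape)  # torch.Size([4, 2])
```

A real BERT-style model would add positional embeddings, a padding mask, and a [CLS] token instead of mean pooling; this sketch only shows how the encoder layer slots into a larger module.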

PyTorch torch.nn Reference