PyTorch torch.nn.TransformerEncoder 函数
torch.nn.TransformerEncoder 是 PyTorch 中的 Transformer 编码器模块。
它由多个 TransformerEncoderLayer 组成,用于处理输入序列。
函数定义
torch.nn.TransformerEncoder(encoder_layer, num_layers, norm=None)
参数
encoder_layer: 编码器层实例（nn.TransformerEncoderLayer）；num_layers: 编码器层的堆叠数量；norm: 可选的层归一化模块，默认为 None
使用示例
示例 1: 基本用法
实例
import torch
import torch.nn as nn

# Build one encoder layer, then stack it into a full encoder.
layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
transformer_encoder = nn.TransformerEncoder(layer, num_layers=6)

# Random input tensor of shape (batch, seq, d_model).
x = torch.randn(32, 100, 512)

# The encoder preserves the input shape.
output = transformer_encoder(x)
print("输入形状:", x.shape)
print("输出形状:", output.shape)
# FIX: `import torch.nn as nn` alone does not bind the name `torch`,
# so `torch.randn` below would raise NameError; import torch explicitly.
import torch
import torch.nn as nn

# Single encoder layer (batch_first=True -> input is (batch, seq, d_model)).
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
# Full encoder: the layer is stacked 6 times.
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

# Input: (batch, seq, d_model).
x = torch.randn(32, 100, 512)
output = transformer_encoder(x)
print("输入形状:", x.shape)
print("输出形状:", output.shape)
示例 2: 文本分类
实例
import torch
import torch.nn as nn

# FIX: the original listing had all indentation stripped and was not valid
# Python; structure reconstructed from the statement syntax.
class TransformerClassifier(nn.Module):
    """Text classifier: embedding + learned positions + Transformer encoder.

    Classifies a token-id sequence of shape (batch, seq) into 10 classes
    using the encoder output at the first token position.
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=4):
        super(TransformerClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional encoding; supports sequences up to length 512.
        self.pos_encoder = nn.Parameter(torch.randn(1, 512, d_model) * 0.1)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, 10)

    def forward(self, x):
        # Add positional encoding, truncated to the actual sequence length.
        x = self.embedding(x) + self.pos_encoder[:, :x.size(1), :]
        x = self.encoder(x)
        # Classify from the first token's representation (CLS-style pooling).
        return self.fc(x[:, 0, :])


model = TransformerClassifier(10000)
x = torch.randint(0, 10000, (32, 100))
output = model(x)
print("输出形状:", output.shape)
# FIX 1: `import torch.nn as nn` does not bind `torch`, yet `torch.randn` /
# `torch.randint` are used below — import torch explicitly.
# FIX 2: all indentation was stripped from the original listing;
# structure reconstructed from the statement syntax.
import torch
import torch.nn as nn


class TransformerClassifier(nn.Module):
    """Text classifier: embedding + learned positions + Transformer encoder.

    Maps token ids (batch, seq) to logits over 10 classes, pooling via the
    first token position.
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=4):
        super(TransformerClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional encoding; supports sequences up to length 512.
        self.pos_encoder = nn.Parameter(torch.randn(1, 512, d_model) * 0.1)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, 10)

    def forward(self, x):
        # Add positional encoding, truncated to the actual sequence length.
        x = self.embedding(x) + self.pos_encoder[:, :x.size(1), :]
        x = self.encoder(x)
        # Classify from the first token's representation (CLS-style pooling).
        return self.fc(x[:, 0, :])


model = TransformerClassifier(10000)
x = torch.randint(0, 10000, (32, 100))
output = model(x)
print("输出形状:", output.shape)
使用场景
- BERT: 编码器基础
- 文本分类
- 特征提取

PyTorch torch.nn 参考手册