现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.TransformerEncoder 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.TransformerEncoder 是 PyTorch 中的 Transformer 编码器模块。

它由多个 TransformerEncoderLayer 组成,用于处理输入序列。

函数定义

torch.nn.TransformerEncoder(encoder_layer, num_layers, norm=None)

参数

  • encoder_layer: 编码器层
  • num_layers: 层数

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

# 单个编码器层
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)

# 完整编码器
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

# 输入
x = torch.randn(32, 100, 512)  # batch, seq, d_model
output = transformer_encoder(x)

print("输入形状:", x.shape)
print("输出形状:", output.shape)

示例 2: 文本分类

实例

import torch
import torch.nn as nn

class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=4):
        super(TransformerClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = nn.Parameter(torch.randn(1, 512, d_model) * 0.1)

        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, 10)

    def forward(self, x):
        # 简单位置编码
        x = self.embedding(x) + self.pos_encoder[:, :x.size(1), :]
        x = self.encoder(x)
        # 取第一个 token
        return self.fc(x[:, 0, :])

model = TransformerClassifier(10000)
x = torch.randint(0, 10000, (32, 100))
output = model(x)

print("输出形状:", output.shape)

使用场景

  • BERT: 编码器基础
  • 文本分类
  • 特征提取

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册