现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.Transformer 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.Transformer 是 PyTorch 中的完整 Transformer 模型。

它包含编码器和解码器,可用于序列到序列的任务。

函数定义

torch.nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, activation='gelu', batch_first=True)

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

# 创建 Transformer
transformer = nn.Transformer(d_model=512, nhead=8, batch_first=True)

# 编码器输入
src = torch.randn(10, 32, 512)  # (seq, batch, d_model)
# 解码器输入
tgt = torch.randn(20, 32, 512)

output = transformer(src, tgt)

print("输出形状:", output.shape)

示例 2: 简单翻译模型

实例

import torch
import torch.nn as nn

class TransformerMT(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=4):
        super(TransformerMT, self).__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.Transformer(d_model, nhead, batch_first=True)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        src = self.embedding(src) * (self.d_model ** 0.5)
        tgt = self.embedding(tgt) * (self.d_model ** 0.5)
        out = self.transformer(src, tgt)
        return self.fc(out)

model = TransformerMT(10000)
src = torch.randint(0, 10000, (32, 50))
tgt = torch.randint(0, 10000, (32, 40))
output = model(src, tgt)

print("输出形状:", output.shape)

使用场景

  • 机器翻译
  • 文本生成
  • 序列到序列

提示:batch_first=True 时,输入形状为 (batch, seq, d_model)。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册