PyTorch torch.nn.Transformer 函数
torch.nn.Transformer 是 PyTorch 中的完整 Transformer 模型。
它包含编码器和解码器,可用于序列到序列的任务。
函数定义
torch.nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, activation='relu', batch_first=False)
使用示例
示例 1: 基本用法
实例
import torch
import torch.nn as nn

# Build a full encoder-decoder Transformer. With batch_first=True every
# input and output tensor is laid out as (batch, seq, d_model).
transformer = nn.Transformer(d_model=512, nhead=8, batch_first=True)

# Encoder input: a batch of 32 sequences with 10 positions each.
src = torch.randn(32, 10, 512)  # (batch, src_seq, d_model)

# Decoder input: must share the batch size of src; 20 positions each.
tgt = torch.randn(32, 20, 512)  # (batch, tgt_seq, d_model)

# The output follows the decoder input's layout: (32, 20, 512).
output = transformer(src, tgt)
print("输出形状:", output.shape)
import torch
import torch.nn as nn

# Build a full encoder-decoder Transformer. With batch_first=True every
# input and output tensor is laid out as (batch, seq, d_model).
transformer = nn.Transformer(d_model=512, nhead=8, batch_first=True)

# Encoder input: a batch of 32 sequences with 10 positions each.
src = torch.randn(32, 10, 512)  # (batch, src_seq, d_model)

# Decoder input: must share the batch size of src; 20 positions each.
tgt = torch.randn(32, 20, 512)  # (batch, tgt_seq, d_model)

# The output follows the decoder input's layout: (32, 20, 512).
output = transformer(src, tgt)
print("输出形状:", output.shape)
示例 2: 简单翻译模型
实例
import torch
import torch.nn as nn


class TransformerMT(nn.Module):
    """Minimal sequence-to-sequence translation model built on nn.Transformer.

    Source and target token ids share a single embedding table. The
    Transformer output is projected back to vocabulary logits.

    NOTE: no positional encoding is added to the embeddings, so this
    example is intentionally simplified; a real model needs one.

    Args:
        vocab_size: Number of entries in the shared vocabulary.
        d_model: Embedding / model width.
        nhead: Number of attention heads (must divide d_model).
    """

    def __init__(self, vocab_size, d_model=256, nhead=4):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.Transformer(d_model, nhead, batch_first=True)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        """Return vocabulary logits for every target position.

        Args:
            src: Integer token ids of shape (batch, src_seq).
            tgt: Integer token ids of shape (batch, tgt_seq).

        Returns:
            Float tensor of shape (batch, tgt_seq, vocab_size).
        """
        # Scale embeddings by sqrt(d_model), as in the original Transformer.
        scale = self.d_model ** 0.5
        src = self.embedding(src) * scale
        tgt = self.embedding(tgt) * scale

        # Causal mask: True marks positions that may NOT be attended to, so
        # target position i only sees positions <= i and cannot peek at the
        # tokens it is supposed to predict.
        tgt_len = tgt.size(1)
        tgt_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device),
            diagonal=1,
        )

        out = self.transformer(src, tgt, tgt_mask=tgt_mask)
        return self.fc(out)


model = TransformerMT(10000)
src = torch.randint(0, 10000, (32, 50))  # (batch, src_seq) token ids
tgt = torch.randint(0, 10000, (32, 40))  # (batch, tgt_seq) token ids
output = model(src, tgt)  # (32, 40, 10000)
print("输出形状:", output.shape)
import torch
import torch.nn as nn


class TransformerMT(nn.Module):
    """Minimal sequence-to-sequence translation model built on nn.Transformer.

    Source and target token ids share a single embedding table. The
    Transformer output is projected back to vocabulary logits.

    NOTE: no positional encoding is added to the embeddings, so this
    example is intentionally simplified; a real model needs one.

    Args:
        vocab_size: Number of entries in the shared vocabulary.
        d_model: Embedding / model width.
        nhead: Number of attention heads (must divide d_model).
    """

    def __init__(self, vocab_size, d_model=256, nhead=4):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.Transformer(d_model, nhead, batch_first=True)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        """Return vocabulary logits for every target position.

        Args:
            src: Integer token ids of shape (batch, src_seq).
            tgt: Integer token ids of shape (batch, tgt_seq).

        Returns:
            Float tensor of shape (batch, tgt_seq, vocab_size).
        """
        # Scale embeddings by sqrt(d_model), as in the original Transformer.
        scale = self.d_model ** 0.5
        src = self.embedding(src) * scale
        tgt = self.embedding(tgt) * scale

        # Causal mask: True marks positions that may NOT be attended to, so
        # target position i only sees positions <= i and cannot peek at the
        # tokens it is supposed to predict.
        tgt_len = tgt.size(1)
        tgt_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device),
            diagonal=1,
        )

        out = self.transformer(src, tgt, tgt_mask=tgt_mask)
        return self.fc(out)


model = TransformerMT(10000)
src = torch.randint(0, 10000, (32, 50))  # (batch, src_seq) token ids
tgt = torch.randint(0, 10000, (32, 40))  # (batch, tgt_seq) token ids
output = model(src, tgt)  # (32, 40, 10000)
print("输出形状:", output.shape)
使用场景
- 机器翻译
- 文本生成
- 序列到序列
提示:batch_first=True 时,输入形状为 (batch, seq, d_model);batch_first=False 时为 (seq, batch, d_model)。无论哪种布局,src 与 tgt 的 batch 维大小必须一致。

PyTorch torch.nn 参考手册