PyTorch torch.nn.Embedding 函数
torch.nn.Embedding 是 PyTorch 中用于词嵌入的模块。
它将离散的词汇索引映射到连续的向量空间,是自然语言处理中最基础的操作之一。
函数定义
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)
参数说明:
num_embeddings(int): 词汇表大小,即嵌入矩阵的行数。embedding_dim(int): 每个嵌入向量的维度。padding_idx(int): 指定填充索引,其嵌入向量为零向量。默认为 None。max_norm(float): 如果非 None,嵌入向量会被规范化到这个范数。默认为 None。norm_type(float): 计算范数的阶数。默认为 2.0。scale_grad_by_freq(bool): 是否根据词频缩放梯度。默认为 False。sparse(bool): 权重矩阵是否为稀疏矩阵。默认为 False。
属性:
weight(Tensor): 形状为 (num_embeddings, embedding_dim) 的可学习权重。
使用示例
示例 1: 基本用法
创建并使用词嵌入:
实例
import torch
import torch.nn as nn
# 创建嵌入层:词汇表 10000,嵌入维度 256
embedding = nn.Embedding(num_embeddings=10000, embedding_dim=256)
# 词索引(从 0 开始)
# 形状: (batch, seq_len)
input_indices = torch.tensor([[12, 45, 678], [901, 23, 56]])
# 查表得到嵌入向量
output = embedding(input_indices)
print("输入索引形状:", input_indices.shape)
print("输出嵌入形状:", output.shape) # (2, 3, 256)
# 查看嵌入矩阵的形状
print("嵌入矩阵形状:", embedding.weight.shape)
import torch.nn as nn
# 创建嵌入层:词汇表 10000,嵌入维度 256
embedding = nn.Embedding(num_embeddings=10000, embedding_dim=256)
# 词索引(从 0 开始)
# 形状: (batch, seq_len)
input_indices = torch.tensor([[12, 45, 678], [901, 23, 56]])
# 查表得到嵌入向量
output = embedding(input_indices)
print("输入索引形状:", input_indices.shape)
print("输出嵌入形状:", output.shape) # (2, 3, 256)
# 查看嵌入矩阵的形状
print("嵌入矩阵形状:", embedding.weight.shape)
示例 2: 使用 padding_idx
指定填充索引:
实例
import torch
import torch.nn as nn
# 创建带 padding 的嵌入层
embedding = nn.Embedding(num_embeddings=1000, embedding_dim=64, padding_idx=0)
# 0 被用作 padding
input_indices = torch.tensor([[1, 2, 3], [4, 0, 0]]) # 第二句有 padding
output = embedding(input_indices)
print("输入形状:", input_indices.shape)
print("输出形状:", output.shape)
print("padding 的嵌入向量:", output[1, 1].tolist()) # 全 0
print("非 padding 的嵌入向量:", output[0, 0].tolist()[:5]) # 非零
import torch.nn as nn
# 创建带 padding 的嵌入层
embedding = nn.Embedding(num_embeddings=1000, embedding_dim=64, padding_idx=0)
# 0 被用作 padding
input_indices = torch.tensor([[1, 2, 3], [4, 0, 0]]) # 第二句有 padding
output = embedding(input_indices)
print("输入形状:", input_indices.shape)
print("输出形状:", output.shape)
print("padding 的嵌入向量:", output[1, 1].tolist()) # 全 0
print("非 padding 的嵌入向量:", output[0, 0].tolist()[:5]) # 非零
示例 3: 预训练词向量
加载预训练词向量:
实例
import torch
import torch.nn as nn
import numpy as np
# 模拟预训练词向量(实际可用 GloVe、Word2Vec 等)
vocab_size = 1000
embedding_dim = 300
# 随机初始化(实际应加载预训练向量)
pretrained_weights = np.random.randn(vocab_size, embedding_dim).astype('float32')
# 创建嵌入层
embedding = nn.Embedding(vocab_size, embedding_dim)
# 加载预训练权重
embedding.weight.data = torch.from_numpy(pretrained_weights)
# 冻结嵌入层(不参与训练)
embedding.weight.requires_grad = False
print("嵌入层可训练:", embedding.weight.requires_grad)
print("嵌入矩阵形状:", embedding.weight.shape)
import torch.nn as nn
import numpy as np
# 模拟预训练词向量(实际可用 GloVe、Word2Vec 等)
vocab_size = 1000
embedding_dim = 300
# 随机初始化(实际应加载预训练向量)
pretrained_weights = np.random.randn(vocab_size, embedding_dim).astype('float32')
# 创建嵌入层
embedding = nn.Embedding(vocab_size, embedding_dim)
# 加载预训练权重
embedding.weight.data = torch.from_numpy(pretrained_weights)
# 冻结嵌入层(不参与训练)
embedding.weight.requires_grad = False
print("嵌入层可训练:", embedding.weight.requires_grad)
print("嵌入矩阵形状:", embedding.weight.shape)
示例 4: 限制嵌入向量范数
使用 max_norm 限制向量范数:
h2 class="example">实例
import torch
import torch.nn as nn
# 限制嵌入向量最大范数为 1.0
embedding = nn.Embedding(1000, 64, max_norm=1.0)
# 输入
input_indices = torch.tensor([1, 2, 3])
# 原始权重范数
original_norm = embedding.weight.data.norm(dim=1)[:3]
print("原始权重范数:", original_norm.tolist())
# 查表后的向量范数
output = embedding(input_indices)
output_norm = output.norm(dim=1)
print("输出向量范数:", output_norm.tolist())
import torch.nn as nn
# 限制嵌入向量最大范数为 1.0
embedding = nn.Embedding(1000, 64, max_norm=1.0)
# 输入
input_indices = torch.tensor([1, 2, 3])
# 原始权重范数
original_norm = embedding.weight.data.norm(dim=1)[:3]
print("原始权重范数:", original_norm.tolist())
# 查表后的向量范数
output = embedding(input_indices)
output_norm = output.norm(dim=1)
print("输出向量范数:", output_norm.tolist())
示例 5: 完整的文本分类模型
使用嵌入层的文本分类:
实例
import torch
import torch.nn as nn
class TextClassifier(nn.Module):
def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes=2):
super(TextClassifier, self).__init__()
# 词嵌入层
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
# LSTM
self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
# 分类器
self.classifier = nn.Linear(hidden_dim, num_classes)
def forward(self, x):
# x: (batch, seq_len)
embedded = self.embedding(x) # (batch, seq_len, embed_dim)
# LSTM 取最后一个输出
_, (hidden, _) = self.lstm(embedded)
hidden = hidden[-1] # 最后一层的隐藏状态
# 分类
logits = self.classifier(hidden)
return logits
# 参数
VOCAB_SIZE = 10000
EMBED_DIM = 128
HIDDEN_DIM = 256
model = TextClassifier(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM)
# 输入:batch=4,序列长度=50
input_ids = torch.randint(1, VOCAB_SIZE, (4, 50))
output = model(input_ids)
print("模型结构:")
print(model)
print("n输入形状:", input_ids.shape)
print("输出形状:", output.shape)
import torch.nn as nn
class TextClassifier(nn.Module):
def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes=2):
super(TextClassifier, self).__init__()
# 词嵌入层
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
# LSTM
self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
# 分类器
self.classifier = nn.Linear(hidden_dim, num_classes)
def forward(self, x):
# x: (batch, seq_len)
embedded = self.embedding(x) # (batch, seq_len, embed_dim)
# LSTM 取最后一个输出
_, (hidden, _) = self.lstm(embedded)
hidden = hidden[-1] # 最后一层的隐藏状态
# 分类
logits = self.classifier(hidden)
return logits
# 参数
VOCAB_SIZE = 10000
EMBED_DIM = 128
HIDDEN_DIM = 256
model = TextClassifier(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM)
# 输入:batch=4,序列长度=50
input_ids = torch.randint(1, VOCAB_SIZE, (4, 50))
output = model(input_ids)
print("模型结构:")
print(model)
print("n输入形状:", input_ids.shape)
print("输出形状:", output.shape)
Embedding 与 EmbeddingBag 的区别
| 类型 | 输入 | 输出 | 适用场景 |
|---|---|---|---|
nn.Embedding |
词索引 | 词向量序列 | 序列模型、LSTM、Transformer |
nn.EmbeddingBag |
词索引 + 偏移 | 聚合后的向量 | 文本分类、快速处理 |
常见问题
Q1: 嵌入维度如何选择?
- 小数据集:50-100 维
- 中等数据集:100-300 维
- 大数据集:300-500 维
Q2: padding_idx 的作用是什么?
将指定索引的嵌入向量设为零向量,并在反向传播时不计算其梯度。
Q3: 何时冻结嵌入层?
使用预训练词向量时,通常先冻结训练一段时间后再微调。
使用场景
nn.Embedding 主要应用场景包括:
- 词向量表示: 将单词转换为密集向量
- 文本分类: 作为 NLP 模型的输入层
- 序列模型: LSTM、GRU 的输入
- 推荐系统: 用户和物品的嵌入表示
注意:embedding.weight 是一个可学习的参数,可以直接在优化器中训练。

PyTorch torch.nn 参考手册