现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.Embedding 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.Embedding 是 PyTorch 中用于词嵌入的模块。

它将离散的词汇索引映射到连续的向量空间,是自然语言处理中最基础的操作之一。

函数定义

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)

参数说明:

  • num_embeddings (int): 词汇表大小,即嵌入矩阵的行数。
  • embedding_dim (int): 每个嵌入向量的维度。
  • padding_idx (int): 指定填充索引,其嵌入向量为零向量。默认为 None。
  • max_norm (float): 如果非 None,嵌入向量会被规范化到这个范数。默认为 None。
  • norm_type (float): 计算范数的阶数。默认为 2.0。
  • scale_grad_by_freq (bool): 是否根据词频缩放梯度。默认为 False。
  • sparse (bool): 权重矩阵是否为稀疏矩阵。默认为 False。

属性:

  • weight (Tensor): 形状为 (num_embeddings, embedding_dim) 的可学习权重。

使用示例

示例 1: 基本用法

创建并使用词嵌入:

实例

import torch
import torch.nn as nn

# 创建嵌入层:词汇表 10000,嵌入维度 256
embedding = nn.Embedding(num_embeddings=10000, embedding_dim=256)

# 词索引(从 0 开始)
# 形状: (batch, seq_len)
input_indices = torch.tensor([[12, 45, 678], [901, 23, 56]])

# 查表得到嵌入向量
output = embedding(input_indices)

print("输入索引形状:", input_indices.shape)
print("输出嵌入形状:", output.shape)  # (2, 3, 256)

# 查看嵌入矩阵的形状
print("嵌入矩阵形状:", embedding.weight.shape)

示例 2: 使用 padding_idx

指定填充索引:

实例

import torch
import torch.nn as nn

# 创建带 padding 的嵌入层
embedding = nn.Embedding(num_embeddings=1000, embedding_dim=64, padding_idx=0)

# 0 被用作 padding
input_indices = torch.tensor([[1, 2, 3], [4, 0, 0]])  # 第二句有 padding

output = embedding(input_indices)

print("输入形状:", input_indices.shape)
print("输出形状:", output.shape)
print("padding 的嵌入向量:", output[1, 1].tolist())  # 全 0
print("非 padding 的嵌入向量:", output[0, 0].tolist()[:5])  # 非零

示例 3: 预训练词向量

加载预训练词向量:

实例

import torch
import torch.nn as nn
import numpy as np

# 模拟预训练词向量(实际可用 GloVe、Word2Vec 等)
vocab_size = 1000
embedding_dim = 300

# 随机初始化(实际应加载预训练向量)
pretrained_weights = np.random.randn(vocab_size, embedding_dim).astype('float32')

# 创建嵌入层
embedding = nn.Embedding(vocab_size, embedding_dim)

# 加载预训练权重
embedding.weight.data = torch.from_numpy(pretrained_weights)

# 冻结嵌入层(不参与训练)
embedding.weight.requires_grad = False

print("嵌入层可训练:", embedding.weight.requires_grad)
print("嵌入矩阵形状:", embedding.weight.shape)

示例 4: 限制嵌入向量范数

使用 max_norm 限制向量范数:

h2 class="example">实例
import torch
import torch.nn as nn

# 限制嵌入向量最大范数为 1.0
embedding = nn.Embedding(1000, 64, max_norm=1.0)

# 输入
input_indices = torch.tensor([1, 2, 3])

# 原始权重范数
original_norm = embedding.weight.data.norm(dim=1)[:3]
print("原始权重范数:", original_norm.tolist())

# 查表后的向量范数
output = embedding(input_indices)
output_norm = output.norm(dim=1)
print("输出向量范数:", output_norm.tolist())

示例 5: 完整的文本分类模型

使用嵌入层的文本分类:

实例

import torch
import torch.nn as nn

class TextClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes=2):
        super(TextClassifier, self).__init__()
        # 词嵌入层
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # LSTM
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        # 分类器
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        # x: (batch, seq_len)
        embedded = self.embedding(x)  # (batch, seq_len, embed_dim)

        # LSTM 取最后一个输出
        _, (hidden, _) = self.lstm(embedded)
        hidden = hidden[-1]  # 最后一层的隐藏状态

        # 分类
        logits = self.classifier(hidden)
        return logits

# 参数
VOCAB_SIZE = 10000
EMBED_DIM = 128
HIDDEN_DIM = 256

model = TextClassifier(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM)

# 输入:batch=4,序列长度=50
input_ids = torch.randint(1, VOCAB_SIZE, (4, 50))
output = model(input_ids)

print("模型结构:")
print(model)
print("n输入形状:", input_ids.shape)
print("输出形状:", output.shape)

Embedding 与 EmbeddingBag 的区别

类型 输入 输出 适用场景
nn.Embedding 词索引 词向量序列 序列模型、LSTM、Transformer
nn.EmbeddingBag 词索引 + 偏移 聚合后的向量 文本分类、快速处理

常见问题

Q1: 嵌入维度如何选择?

  • 小数据集:50-100 维
  • 中等数据集:100-300 维
  • 大数据集:300-500 维

Q2: padding_idx 的作用是什么?

将指定索引的嵌入向量设为零向量,并在反向传播时不计算其梯度。

Q3: 何时冻结嵌入层?

使用预训练词向量时,通常先冻结训练一段时间后再微调。


使用场景

nn.Embedding 主要应用场景包括:

  • 词向量表示: 将单词转换为密集向量
  • 文本分类: 作为 NLP 模型的输入层
  • 序列模型: LSTM、GRU 的输入
  • 推荐系统: 用户和物品的嵌入表示

注意:embedding.weight 是一个可学习的参数,可以直接在优化器中训练。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册