现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.nn.EmbeddingBag 函数

PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册


torch.nn.EmbeddingBag 是 PyTorch 中的嵌入袋模块。

它将多个嵌入向量聚合成单个向量,常用于文本分类和快速处理。

函数定义

torch.nn.EmbeddingBag(num_embeddings, embedding_dim, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, mode='mean', sparse=False, include_last_offset=False)

参数:

  • mode: 聚合方式,可选 'mean''sum''max'

使用示例

示例 1: 基本用法

实例

import torch
import torch.nn as nn

# 嵌入袋
ebag = nn.EmbeddingBag(1000, 128, mode='mean')

# 词索引
indices = torch.tensor([[1, 2, 3], [4, 5]])

# 偏移量(表示每句话的边界)
offsets = torch.tensor([0, 3])

output = ebag(indices, offsets)

print("输入形状:", indices.shape)
print("输出形状:", output.shape)

示例 2: 快速文本分类

实例

import torch
import torch.nn as nn

class FastText(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_classes):
        super(FastText, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, mode='mean')
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, text, offset):
        x = self.embedding(text, offset)
        return self.fc(x)

model = FastText(10000, 128, 2)

# 模拟数据
text = torch.randint(0, 10000, (100,))
offset = torch.tensor([0, 30, 60, 100])

output = model(text, offset)
print("输出形状:", output.shape)

示例 3: 不同聚合方式

实例

import torch
import torch.nn as nn

for mode in ['mean', 'sum', 'max']:
    ebag = nn.EmbeddingBag(100, 32, mode=mode)
    indices = torch.tensor([1, 2, 3, 4])
    offsets = torch.tensor([0, 2])
    out = ebag(indices, offsets)
    print(f"{mode}:", out.shape)

使用场景

  • 文本分类: FastText
  • 词袋模型: 快速聚合
  • 大规模文本: 高效处理

注意:需要提供 offsets 参数指明句子边界。


PyTorch torch.nn 参考手册 PyTorch torch.nn 参考手册