PyTorch torch.nn.EmbeddingBag Function
torch.nn.EmbeddingBag is PyTorch's embedding-bag module.
It looks up several embedding vectors and aggregates them into a single vector per "bag" (for example, per sentence), and is commonly used for text classification and other fast bag-of-words style processing.
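To make "aggregating multiple embedding vectors into a single vector" concrete, here is a small sketch (an illustration, not part of the official docs): an EmbeddingBag with mode='sum' produces the same result as looking up the same rows with nn.Embedding and summing them within each bag.

import torch
import torch.nn as nn

# EmbeddingBag with sum pooling over a tiny 10-word vocabulary
ebag = nn.EmbeddingBag(10, 4, mode='sum')

# A plain Embedding given the same weights, for comparison
emb = nn.Embedding(10, 4)
with torch.no_grad():
    emb.weight.copy_(ebag.weight)

indices = torch.tensor([1, 2, 3, 4, 5])
offsets = torch.tensor([0, 3])   # bag 0 -> [1, 2, 3], bag 1 -> [4, 5]

bagged = ebag(indices, offsets)
manual = torch.stack([emb(indices[0:3]).sum(0), emb(indices[3:5]).sum(0)])
print(torch.allclose(bagged, manual))   # True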
Function Definition
torch.nn.EmbeddingBag(num_embeddings, embedding_dim, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, mode='mean', sparse=False, include_last_offset=False)
Parameters:
num_embeddings: size of the embedding dictionary (vocabulary size)
embedding_dim: dimension of each embedding vector
max_norm: if given, each embedding vector whose norm exceeds max_norm is renormalized to max_norm
norm_type: the p of the p-norm used for max_norm (default 2.0)
scale_grad_by_freq: if True, scale gradients by the inverse frequency of the words in the mini-batch
mode: aggregation method; one of 'mean', 'sum', or 'max'
sparse: if True, the gradient with respect to the weight matrix is a sparse tensor
include_last_offset: if True, offsets carries one extra trailing element giving the end boundary of the last bag
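As a small illustration of include_last_offset (a sketch with made-up indices): when it is True, offsets lists every bag's start boundary plus the end boundary of the last bag, CSR-style, so it has one more element than there are bags.

import torch
import torch.nn as nn

ebag = nn.EmbeddingBag(100, 8, mode='sum', include_last_offset=True)

indices = torch.tensor([1, 2, 3, 4, 5])
# Start boundaries plus the final end boundary (5 = total number of indices):
# bag 0 -> indices[0:3], bag 1 -> indices[3:5]
offsets = torch.tensor([0, 3, 5])

print(ebag(indices, offsets).shape)   # torch.Size([2, 8])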
Usage Examples
Example 1: Basic Usage
Example
import torch
import torch.nn as nn

# EmbeddingBag: vocabulary of 1000 words, 128-dimensional embeddings, mean pooling
ebag = nn.EmbeddingBag(1000, 128, mode='mean')

# Word indices of two sentences, flattened into a single 1D tensor
indices = torch.tensor([1, 2, 3, 4, 5])

# Offsets mark where each sentence (bag) starts: sentence 0 -> [1, 2, 3], sentence 1 -> [4, 5]
offsets = torch.tensor([0, 3])

output = ebag(indices, offsets)
print("Input shape:", indices.shape)    # torch.Size([5])
print("Output shape:", output.shape)    # torch.Size([2, 128])
Example 2: Fast Text Classification
Example
import torch
import torch.nn as nn

class FastText(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_classes):
        super(FastText, self).__init__()
        # EmbeddingBag averages the word embeddings of each text
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, mode='mean')
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, text, offset):
        x = self.embedding(text, offset)
        return self.fc(x)

model = FastText(10000, 128, 2)

# Simulated data: 100 word indices split into 3 texts of lengths 30, 30 and 40
text = torch.randint(0, 10000, (100,))
offset = torch.tensor([0, 30, 60])

output = model(text, offset)
print("Output shape:", output.shape)    # torch.Size([3, 2])
Example 3: Different Aggregation Modes
Example
import torch
import torch.nn as nn

for mode in ['mean', 'sum', 'max']:
    ebag = nn.EmbeddingBag(100, 32, mode=mode)
    # Two bags: [1, 2] and [3, 4]
    indices = torch.tensor([1, 2, 3, 4])
    offsets = torch.tensor([0, 2])
    out = ebag(indices, offsets)
    print(f"{mode}:", out.shape)    # torch.Size([2, 32]) for every mode
Use Cases
- Text classification: FastText-style models
- Bag-of-words models: fast aggregation of word vectors
- Large-scale text: time- and memory-efficient processing
Note: with a 1D index tensor, the offsets argument must be provided to mark sentence (bag) boundaries; with a 2D input of equal-length bags, offsets is not passed.
