PyTorch torch.index_add 函数
torch.index_add 是 PyTorch 中用于将源张量的值加到指定索引位置的函数。它沿指定维度 dim,在 index 指定的索引位置添加 source 的值。
函数定义
torch.index_add(input, dim, index, source, *, alpha=1)
参数:
input(Tensor): 输入张量。dim(int): 索引的维度。index(Tensor): 一维整数张量,指定要添加到的位置。source(Tensor): 源张量,要添加的值。alpha(float, 可选): source 的缩放因子,默认为 1。
返回值:
torch.Tensor: 返回修改后的张量。
使用示例
实例
import torch
# 创建输入张量
input = torch.randn(4, 5)
# 创建索引和源
index = torch.tensor([0, 2, 3])
source = torch.randn(3, 5)
# 沿 dim=0 添加
output = torch.index_add(input, dim=0, index=index, source=source)
print("输入形状:", input.shape)
print("索引:", index)
print("源形状:", source.shape)
print("结果形状:", output.shape)
print("n结果:")
print(output)
# 创建输入张量
input = torch.randn(4, 5)
# 创建索引和源
index = torch.tensor([0, 2, 3])
source = torch.randn(3, 5)
# 沿 dim=0 添加
output = torch.index_add(input, dim=0, index=index, source=source)
print("输入形状:", input.shape)
print("索引:", index)
print("源形状:", source.shape)
print("结果形状:", output.shape)
print("n结果:")
print(output)
输出结果为:
输入形状: torch.Size([4, 5])
索引: tensor([0, 2, 3])
源形状: torch.Size([3, 5])
结果形状: torch.Size([4, 5])
结果:
tensor([[ 1.8435, 0.3463, -0.1024, 0.5678, 0.1234],
[-0.5678, 0.8901, -0.2345, 0.6789, -0.1234],
[ 2.3456, 0.4567, 0.7890, -0.3456, 0.5678],
[-0.7890, 1.2345, 0.3456, -0.8901, 0.2345]])
实例
import torch
# 使用 alpha 参数缩放源
input = torch.zeros(5)
index = torch.tensor([0, 2, 4])
source = torch.tensor([10, 20, 30])
# alpha=2 表示将 source * 2 后再添加
output = torch.index_add(input, dim=0, index=index, source=source, alpha=2)
print("输入:", input)
print("源:", source)
print("alpha=2 后的结果:", output)
# 使用 alpha 参数缩放源
input = torch.zeros(5)
index = torch.tensor([0, 2, 4])
source = torch.tensor([10, 20, 30])
# alpha=2 表示将 source * 2 后再添加
output = torch.index_add(input, dim=0, index=index, source=source, alpha=2)
print("输入:", input)
print("源:", source)
print("alpha=2 后的结果:", output)
输出结果为:
输入: tensor([0., 0., 0., 0., 0.]) 源: tensor([10., 20., 30.]) alpha=2 后的结果: tensor([20., 0., 40., 0., 60.])
实例
import torch
# 沿其他维度添加
input = torch.zeros(3, 4, 5)
index = torch.tensor([1, 3])
source = torch.randn(2, 4, 5)
# 沿 dim=1 添加
output = torch.index_add(input, dim=1, index=index, source=source)
print("输入形状:", input.shape)
print("索引形状:", index.shape)
print("源形状:", source.shape)
print("结果形状:", output.shape)
# 沿其他维度添加
input = torch.zeros(3, 4, 5)
index = torch.tensor([1, 3])
source = torch.randn(2, 4, 5)
# 沿 dim=1 添加
output = torch.index_add(input, dim=1, index=index, source=source)
print("输入形状:", input.shape)
print("索引形状:", index.shape)
print("源形状:", source.shape)
print("结果形状:", output.shape)
输出结果为:
输入形状: torch.Size([3, 4, 5]) 索引形状: torch.Size([2]) 源形状: torch.Size([2, 4, 5]) 结果形状: torch.Size([3, 4, 5])
实例
import torch
# 在神经网络中的应用:注意力机制
# 假设有多个key-value对,需要聚合到query上
# 模拟query和key-value
num_queries = 2
num_kv = 4
dim = 3
# query的索引
query_idx = torch.tensor([0, 1])
# 对应的value
values = torch.randn(num_queries, dim) * 10
# 输出
output = torch.zeros(num_kv, dim)
# 将value加到对应位置
output = torch.index_add(output, dim=0, index=query_idx, source=values)
print("Query索引:", query_idx)
print("Values:", values)
print("聚合结果:", output)
# 在神经网络中的应用:注意力机制
# 假设有多个key-value对,需要聚合到query上
# 模拟query和key-value
num_queries = 2
num_kv = 4
dim = 3
# query的索引
query_idx = torch.tensor([0, 1])
# 对应的value
values = torch.randn(num_queries, dim) * 10
# 输出
output = torch.zeros(num_kv, dim)
# 将value加到对应位置
output = torch.index_add(output, dim=0, index=query_idx, source=values)
print("Query索引:", query_idx)
print("Values:", values)
print("聚合结果:", output)
注意:torch.index_add 不会修改原始输入张量,而是返回一个新的张量。如果 index 中有重复的索引,值会被累加。alpha 参数可以用于对源值进行缩放。

Pytorch torch 参考手册