现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.index_add 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.index_add 是 PyTorch 中用于将源张量的值加到指定索引位置的函数。它沿指定维度 dim,在 index 指定的索引位置添加 source 的值。

函数定义

torch.index_add(input, dim, index, source, *, alpha=1)

参数:

  • input (Tensor): 输入张量。
  • dim (int): 索引的维度。
  • index (Tensor): 一维整数张量,指定要添加到的位置。
  • source (Tensor): 源张量,要添加的值。
  • alpha (float, 可选): source 的缩放因子,默认为 1。

返回值:

  • torch.Tensor: 返回修改后的张量。

使用示例

实例

import torch

# 创建输入张量
input = torch.randn(4, 5)

# 创建索引和源
index = torch.tensor([0, 2, 3])
source = torch.randn(3, 5)

# 沿 dim=0 添加
output = torch.index_add(input, dim=0, index=index, source=source)

print("输入形状:", input.shape)
print("索引:", index)
print("源形状:", source.shape)
print("结果形状:", output.shape)
print("n结果:")
print(output)

输出结果为:

输入形状: torch.Size([4, 5])
索引: tensor([0, 2, 3])
源形状: torch.Size([3, 5])
结果形状: torch.Size([4, 5])

结果:
tensor([[ 1.8435,  0.3463, -0.1024,  0.5678,  0.1234],
        [-0.5678,  0.8901, -0.2345,  0.6789, -0.1234],
        [ 2.3456,  0.4567,  0.7890, -0.3456,  0.5678],
        [-0.7890,  1.2345,  0.3456, -0.8901,  0.2345]])

实例

import torch

# 使用 alpha 参数缩放源
input = torch.zeros(5)
index = torch.tensor([0, 2, 4])
source = torch.tensor([10, 20, 30])

# alpha=2 表示将 source * 2 后再添加
output = torch.index_add(input, dim=0, index=index, source=source, alpha=2)

print("输入:", input)
print("源:", source)
print("alpha=2 后的结果:", output)

输出结果为:

输入: tensor([0., 0., 0., 0., 0.])
源: tensor([10., 20., 30.])
alpha=2 后的结果: tensor([20.,  0., 40.,  0., 60.])

实例

import torch

# 沿其他维度添加
input = torch.zeros(3, 4, 5)
index = torch.tensor([1, 3])
source = torch.randn(2, 4, 5)

# 沿 dim=1 添加
output = torch.index_add(input, dim=1, index=index, source=source)

print("输入形状:", input.shape)
print("索引形状:", index.shape)
print("源形状:", source.shape)
print("结果形状:", output.shape)

输出结果为:

输入形状: torch.Size([3, 4, 5])
索引形状: torch.Size([2])
源形状: torch.Size([2, 4, 5])
结果形状: torch.Size([3, 4, 5])

实例

import torch

# 在神经网络中的应用:注意力机制
# 假设有多个key-value对,需要聚合到query上

# 模拟query和key-value
num_queries = 2
num_kv = 4
dim = 3

# query的索引
query_idx = torch.tensor([0, 1])
# 对应的value
values = torch.randn(num_queries, dim) * 10

# 输出
output = torch.zeros(num_kv, dim)
# 将value加到对应位置
output = torch.index_add(output, dim=0, index=query_idx, source=values)

print("Query索引:", query_idx)
print("Values:", values)
print("聚合结果:", output)

注意:torch.index_add 不会修改原始输入张量,而是返回一个新的张量。如果 index 中有重复的索引,值会被累加。alpha 参数可以用于对源值进行缩放。


Pytorch torch 参考手册 Pytorch torch 参考手册