现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.scatter_add 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.scatter_add 是 PyTorch 中用于将源张量的值加到指定位置的函数。它将 src 的值按照 index 指定的位置加到 input 中。

函数定义

torch.scatter_add(input, dim, index, src)

参数:

  • input (Tensor): 输入张量。
  • dim (int): 散布的维度。
  • index (Tensor): 索引张量,指定要将 src 的值加到 input 的哪个位置。
  • src (Tensor): 源张量,要添加的值。

返回值:

  • torch.Tensor: 返回修改后的张量。

使用示例

实例

import torch

# 创建输入张量
input = torch.zeros(3, 5)

# 创建索引和源
index = torch.tensor([[0, 1, 2, 0, 0],
                      [1, 2, 0, 1, 2],
                      [2, 0, 1, 2, 0]])
src = torch.tensor([[1, 1, 1, 1, 1],
                    [2, 2, 2, 2, 2],
                    [3, 3, 3, 3, 3]])

# 沿 dim=0 散布并累加
output = torch.scatter_add(input, dim=0, index=index, src=src)

print("输入:")
print(input)
print("n索引:")
print(index)
print("n源:")
print(src)
print("n结果:")
print(output)

输出结果为:

输入:
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

索引:
tensor([[0, 1, 2, 0, 0],
        [1, 2, 0, 1, 2],
        [2, 0, 1, 2, 0]])

源:
tensor([[1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.]])

结果:
tensor([[4., 1., 2., 4., 4.],
        [2., 1., 2., 2., 2.],
        [3., 3., 1., 3., 3.]])

实例

import torch

# 使用dim=1
input = torch.zeros(3, 5)
index = torch.tensor([[0, 1, 2, 1, 0],
                      [1, 2, 0, 2, 1],
                      [0, 1, 1, 0, 2]])
src = torch.arange(1, 6).float()

output = torch.scatter_add(input, dim=1, index=index, src=src)

print("沿 dim=1 散布:")
print(output)

输出结果为:

沿 dim=1 散布:
tensor([[ 6.,  3.,  3.,  0.,  0.],
        [ 3.,  6.,  3.,  0.,  0.],
        [ 2.,  4.,  5.,  0.,  0.]])

实例

import torch

# 聚合多个位置的值的应用场景
# 比如在图神经网络中累加邻居节点的特征

# 模拟4个节点的初始特征
node_features = torch.zeros(4, 3)

# 模拟边的连接关系(源节点指向目标节点)
edge_index = torch.tensor([0, 1, 2, 3, 0, 1])  # 边的源节点
edge_weights = torch.tensor([1.0, 2.0, 3.0, 1.5, 2.5, 0.5])

# 为每条边创建源节点特征的加权值
src_features = torch.randn(6, 3) * edge_weights.unsqueeze(1)

# 将特征累加到目标节点(这里简化处理,实际需要根据边的目标节点)
target_nodes = torch.tensor([0, 0, 1, 1, 2, 3])
index = target_nodes

output = torch.scatter_add(node_features, 0, index.unsqueeze(1).expand_as(src_features), src_features)

print("节点特征形状:", node_features.shape)
print("累加后的特征:", output)

注意:torch.scatter_add 不会修改原始输入张量,而是返回一个新的张量。多个索引可以指向同一位置,值会被累加。该函数是 torch.gather 的逆操作。


Pytorch torch 参考手册 Pytorch torch 参考手册