现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.scatter_reduce 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.scatter_reduce 是 PyTorch 中用于将源张量的值按指定方式聚合到指定位置的函数。它支持多种聚合方式,如求和、乘积、最大值、最小值等。

函数定义

torch.scatter_reduce(input, dim, index, src, reduce='sum', *, include_self=True)

参数:

  • input (Tensor): 输入张量。
  • dim (int): 聚合的维度。
  • index (Tensor): 索引张量,指定要将 src 的值聚合到 input 的哪个位置。
  • src (Tensor): 源张量,要聚合的值。
  • reduce (str): 聚合方式,可选值为 'sum'、'prod'、'mean'、'amax'、'amin'、'multiply'。默认为 'sum'。
  • include_self (bool, 可选): 是否在聚合中包含索引位置本身的原始值。默认为 True。

返回值:

  • torch.Tensor: 返回聚合后的张量。

使用示例

实例

import torch

# 创建输入张量
input = torch.ones(3, 5)

# 创建索引和源
index = torch.tensor([[0, 1, 2, 0, 0],
                      [1, 2, 0, 1, 2],
                      [2, 0, 1, 2, 0]])
src = torch.tensor([[1, 1, 1, 1, 1],
                    [2, 2, 2, 2, 2],
                    [3, 3, 3, 3, 3]])

# 使用 sum 聚合
output = torch.scatter_reduce(input, dim=0, index=index, src=src, reduce='sum')

print("输入:")
print(input)
print("n聚合方式: sum")
print("结果:")
print(output)

输出结果为:

输入:
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])

聚合方式: sum
结果:
tensor([[4., 2., 3., 4., 4.],
        [2., 2., 2., 2., 2.],
        [3., 3., 2., 3., 3.]])

实例

import torch

# 测试不同的聚合方式
input = torch.zeros(3, 3)
index = torch.tensor([[0, 0, 0],
                      [1, 1, 1],
                      [2, 2, 2]])
src = torch.tensor([[2, 3, 4],
                    [5, 6, 7],
                    [8, 9, 10]])

# 使用 prod (乘积)
output_prod = torch.scatter_reduce(input, 0, index, src, reduce='prod')
print("prod 聚合:")
print(output_prod)

# 使用 amax (最大值)
output_max = torch.scatter_reduce(input, 0, index, src, reduce='amax')
print("namax 聚合:")
print(output_max)

# 使用 amin (最小值)
output_min = torch.scatter_reduce(input, 0, index, src, reduce='amin')
print("namin 聚合:")
print(output_min)

输出结果为:

prod 聚合:
tensor([[ 2.,  3.,  4.],
        [ 5.,  6.,  7.],
        [ 8.,  9., 10.]])

amax 聚合:
tensor([[ 2.,  3.,  4.],
        [ 5.,  6.,  7.],
        [ 8.,  9., 10.]])

amin 聚合:
tensor([[ 2.,  3.,  4.],
        [ 5.,  6.,  7.],
        [ 8.,  9., 10.]])

实例

import torch

# 使用 include_self 参数
input = torch.tensor([1.0, 2.0, 3.0])
index = torch.tensor([0, 0, 0])
src = torch.tensor([10.0, 20.0, 30.0])

# 包含自身值 (默认)
output1 = torch.scatter_reduce(input, 0, index, src, reduce='sum', include_self=True)
print("include_self=True:", output1)

# 不包含自身值
output2 = torch.scatter_reduce(input, 0, index, src, reduce='sum', include_self=False)
print("include_self=False:", output2)

输出结果为:

include_self=True: tensor([66.,  2.,  3.])
include_self=False: tensor([60.,  2.,  3.])

注意:torch.scatter_reduce 不会修改原始输入张量,而是返回一个新的张量。多个索引可以指向同一位置,值会按照指定的聚合方式进行聚合。include_self=False 在某些场景下很有用,比如在图神经网络中避免将节点自身的特征重复计算。


Pytorch torch 参考手册 Pytorch torch 参考手册