PyTorch torch.scatter_reduce 函数
torch.scatter_reduce 是 PyTorch 中用于将源张量的值按指定方式聚合到指定位置的函数。它支持多种聚合方式,如求和、乘积、最大值、最小值等。
函数定义
torch.scatter_reduce(input, dim, index, src, reduce='sum', *, include_self=True)
参数:
input(Tensor): 输入张量。dim(int): 聚合的维度。index(Tensor): 索引张量,指定要将 src 的值聚合到 input 的哪个位置。src(Tensor): 源张量,要聚合的值。reduce(str): 聚合方式,可选值为 'sum'、'prod'、'mean'、'amax'、'amin'、'multiply'。默认为 'sum'。include_self(bool, 可选): 是否在聚合中包含索引位置本身的原始值。默认为 True。
返回值:
torch.Tensor: 返回聚合后的张量。
使用示例
实例
import torch
# 创建输入张量
input = torch.ones(3, 5)
# 创建索引和源
index = torch.tensor([[0, 1, 2, 0, 0],
[1, 2, 0, 1, 2],
[2, 0, 1, 2, 0]])
src = torch.tensor([[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2],
[3, 3, 3, 3, 3]])
# 使用 sum 聚合
output = torch.scatter_reduce(input, dim=0, index=index, src=src, reduce='sum')
print("输入:")
print(input)
print("n聚合方式: sum")
print("结果:")
print(output)
# 创建输入张量
input = torch.ones(3, 5)
# 创建索引和源
index = torch.tensor([[0, 1, 2, 0, 0],
[1, 2, 0, 1, 2],
[2, 0, 1, 2, 0]])
src = torch.tensor([[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2],
[3, 3, 3, 3, 3]])
# 使用 sum 聚合
output = torch.scatter_reduce(input, dim=0, index=index, src=src, reduce='sum')
print("输入:")
print(input)
print("n聚合方式: sum")
print("结果:")
print(output)
输出结果为:
输入:
tensor([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])
聚合方式: sum
结果:
tensor([[4., 2., 3., 4., 4.],
[2., 2., 2., 2., 2.],
[3., 3., 2., 3., 3.]])
实例
import torch
# 测试不同的聚合方式
input = torch.zeros(3, 3)
index = torch.tensor([[0, 0, 0],
[1, 1, 1],
[2, 2, 2]])
src = torch.tensor([[2, 3, 4],
[5, 6, 7],
[8, 9, 10]])
# 使用 prod (乘积)
output_prod = torch.scatter_reduce(input, 0, index, src, reduce='prod')
print("prod 聚合:")
print(output_prod)
# 使用 amax (最大值)
output_max = torch.scatter_reduce(input, 0, index, src, reduce='amax')
print("namax 聚合:")
print(output_max)
# 使用 amin (最小值)
output_min = torch.scatter_reduce(input, 0, index, src, reduce='amin')
print("namin 聚合:")
print(output_min)
# 测试不同的聚合方式
input = torch.zeros(3, 3)
index = torch.tensor([[0, 0, 0],
[1, 1, 1],
[2, 2, 2]])
src = torch.tensor([[2, 3, 4],
[5, 6, 7],
[8, 9, 10]])
# 使用 prod (乘积)
output_prod = torch.scatter_reduce(input, 0, index, src, reduce='prod')
print("prod 聚合:")
print(output_prod)
# 使用 amax (最大值)
output_max = torch.scatter_reduce(input, 0, index, src, reduce='amax')
print("namax 聚合:")
print(output_max)
# 使用 amin (最小值)
output_min = torch.scatter_reduce(input, 0, index, src, reduce='amin')
print("namin 聚合:")
print(output_min)
输出结果为:
prod 聚合:
tensor([[ 2., 3., 4.],
[ 5., 6., 7.],
[ 8., 9., 10.]])
amax 聚合:
tensor([[ 2., 3., 4.],
[ 5., 6., 7.],
[ 8., 9., 10.]])
amin 聚合:
tensor([[ 2., 3., 4.],
[ 5., 6., 7.],
[ 8., 9., 10.]])
实例
import torch
# 使用 include_self 参数
input = torch.tensor([1.0, 2.0, 3.0])
index = torch.tensor([0, 0, 0])
src = torch.tensor([10.0, 20.0, 30.0])
# 包含自身值 (默认)
output1 = torch.scatter_reduce(input, 0, index, src, reduce='sum', include_self=True)
print("include_self=True:", output1)
# 不包含自身值
output2 = torch.scatter_reduce(input, 0, index, src, reduce='sum', include_self=False)
print("include_self=False:", output2)
# 使用 include_self 参数
input = torch.tensor([1.0, 2.0, 3.0])
index = torch.tensor([0, 0, 0])
src = torch.tensor([10.0, 20.0, 30.0])
# 包含自身值 (默认)
output1 = torch.scatter_reduce(input, 0, index, src, reduce='sum', include_self=True)
print("include_self=True:", output1)
# 不包含自身值
output2 = torch.scatter_reduce(input, 0, index, src, reduce='sum', include_self=False)
print("include_self=False:", output2)
输出结果为:
include_self=True: tensor([66., 2., 3.]) include_self=False: tensor([60., 2., 3.])
注意:torch.scatter_reduce 不会修改原始输入张量,而是返回一个新的张量。多个索引可以指向同一位置,值会按照指定的聚合方式进行聚合。include_self=False 在某些场景下很有用,比如在图神经网络中避免将节点自身的特征重复计算。

Pytorch torch 参考手册