PyTorch torch.index_reduce 函数
torch.index_reduce 是 PyTorch 中用于将源张量的值按指定方式聚合到指定索引位置的函数。它沿指定维度 dim,在 index 指定的索引位置按指定方式聚合 source 的值。
函数定义
torch.index_reduce(input, dim, index, source, reduce='mean', *, include_self=True)
参数:
input(Tensor): 输入张量。dim(int): 索引的维度。index(Tensor): 一维整数张量,指定要聚合到的位置。source(Tensor): 源张量,要聚合的值。reduce(str): 聚合方式,可选值为 'mean'、'prod'、'amax'、'amin'。默认为 'mean'。include_self(bool, 可选): 是否在聚合中包含索引位置本身的原始值。默认为 True。
返回值:
torch.Tensor: 返回聚合后的张量。
使用示例
实例
import torch
# 创建输入张量
input = torch.randn(3, 3)
# 创建索引和源
index = torch.tensor([0, 0, 0])
source = torch.tensor([1.0, 2.0, 3.0])
# 使用 mean 聚合
output = torch.index_reduce(input, dim=0, index=index, source=source, reduce='mean')
print("输入:")
print(input)
print("n索引:", index)
print("源:", source)
print("nmean 聚合结果:")
print(output)
# 创建输入张量
input = torch.randn(3, 3)
# 创建索引和源
index = torch.tensor([0, 0, 0])
source = torch.tensor([1.0, 2.0, 3.0])
# 使用 mean 聚合
output = torch.index_reduce(input, dim=0, index=index, source=source, reduce='mean')
print("输入:")
print(input)
print("n索引:", index)
print("源:", source)
print("nmean 聚合结果:")
print(output)
输出结果为:
输入:
tensor([[ 0.3456, -0.1234, 0.5678],
[-0.5678, 0.1234, -0.6789],
[ 0.7890, -0.3456, 0.1234]])
索引: tensor([0, 0, 0])
源: tensor([1., 2., 3.])
mean 聚合结果:
tensor([[ 2.5237, 0.2931, 0.3374],
[-0.5678, 0.1234, -0.6789],
[ 0.7890, -0.3456, 0.1234]])
实例
import torch
# 测试不同的聚合方式
input = torch.ones(3)
index = torch.tensor([0, 0, 0])
source = torch.tensor([2.0, 4.0, 6.0])
# prod 聚合
output_prod = torch.index_reduce(input, 0, index, source, reduce='prod')
print("prod 聚合:", output_prod)
# amax 聚合
output_max = torch.index_reduce(input, 0, index, source, reduce='amax')
print("amax 聚合:", output_max)
# amin 聚合
output_min = torch.index_reduce(input, 0, index, source, reduce='amin')
print("amin 聚合:", output_min)
# 测试不同的聚合方式
input = torch.ones(3)
index = torch.tensor([0, 0, 0])
source = torch.tensor([2.0, 4.0, 6.0])
# prod 聚合
output_prod = torch.index_reduce(input, 0, index, source, reduce='prod')
print("prod 聚合:", output_prod)
# amax 聚合
output_max = torch.index_reduce(input, 0, index, source, reduce='amax')
print("amax 聚合:", output_max)
# amin 聚合
output_min = torch.index_reduce(input, 0, index, source, reduce='amin')
print("amin 聚合:", output_min)
输出结果为:
prod 聚合: tensor([48., 1., 1.]) amax 聚合: tensor([7., 1., 1.]) amin 聚合: tensor([3., 1., 1.])
实例
import torch
# include_self 参数
input = torch.tensor([1.0, 10.0, 100.0])
index = torch.tensor([0, 0])
source = torch.tensor([5.0, 5.0])
# 包含自身值 (默认)
output1 = torch.index_reduce(input, 0, index, source, reduce='mean', include_self=True)
print("include_self=True:", output1)
# 不包含自身值
output2 = torch.index_reduce(input, 0, index, source, reduce='mean', include_self=False)
print("include_self=False:", output2)
# include_self 参数
input = torch.tensor([1.0, 10.0, 100.0])
index = torch.tensor([0, 0])
source = torch.tensor([5.0, 5.0])
# 包含自身值 (默认)
output1 = torch.index_reduce(input, 0, index, source, reduce='mean', include_self=True)
print("include_self=True:", output1)
# 不包含自身值
output2 = torch.index_reduce(input, 0, index, source, reduce='mean', include_self=False)
print("include_self=False:", output2)
输出结果为:
include_self=True: tensor([ 3.6667, 10.0000, 100.0000]) include_self=False: tensor([ 5., 10., 100.])
实例
import torch
# 2D张量应用
input = torch.zeros(3, 4)
index = torch.tensor([0, 2, 2])
source = torch.randn(3, 4)
# mean 聚合
output = torch.index_reduce(input, dim=0, index=index, source=source, reduce='mean')
print("输入形状:", input.shape)
print("索引:", index)
print("源形状:", source.shape)
print("结果形状:", output.shape)
print("n结果:")
print(output)
# 2D张量应用
input = torch.zeros(3, 4)
index = torch.tensor([0, 2, 2])
source = torch.randn(3, 4)
# mean 聚合
output = torch.index_reduce(input, dim=0, index=index, source=source, reduce='mean')
print("输入形状:", input.shape)
print("索引:", index)
print("源形状:", source.shape)
print("结果形状:", output.shape)
print("n结果:")
print(output)
注意:torch.index_reduce 不会修改原始输入张量,而是返回一个新的张量。多个索引可以指向同一位置,值会按照指定的聚合方式进行聚合。include_self=False 在某些场景下很有用,可以排除原始值只聚合新添加的值。

Pytorch torch 参考手册