现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.index_reduce 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.index_reduce 是 PyTorch 中用于将源张量的值按指定方式聚合到指定索引位置的函数。它沿指定维度 dim,在 index 指定的索引位置按指定方式聚合 source 的值。

函数定义

torch.index_reduce(input, dim, index, source, reduce='mean', *, include_self=True)

参数:

  • input (Tensor): 输入张量。
  • dim (int): 索引的维度。
  • index (Tensor): 一维整数张量,指定要聚合到的位置。
  • source (Tensor): 源张量,要聚合的值。
  • reduce (str): 聚合方式,可选值为 'mean'、'prod'、'amax'、'amin'。默认为 'mean'。
  • include_self (bool, 可选): 是否在聚合中包含索引位置本身的原始值。默认为 True。

返回值:

  • torch.Tensor: 返回聚合后的张量。

使用示例

实例

import torch

# 创建输入张量
input = torch.randn(3, 3)

# 创建索引和源
index = torch.tensor([0, 0, 0])
source = torch.tensor([1.0, 2.0, 3.0])

# 使用 mean 聚合
output = torch.index_reduce(input, dim=0, index=index, source=source, reduce='mean')

print("输入:")
print(input)
print("n索引:", index)
print("源:", source)
print("nmean 聚合结果:")
print(output)

输出结果为:

输入:
tensor([[ 0.3456, -0.1234,  0.5678],
        [-0.5678,  0.1234, -0.6789],
        [ 0.7890, -0.3456,  0.1234]])

索引: tensor([0, 0, 0])
源: tensor([1., 2., 3.])

mean 聚合结果:
tensor([[ 2.5237,  0.2931,  0.3374],
        [-0.5678,  0.1234, -0.6789],
        [ 0.7890, -0.3456,  0.1234]])

实例

import torch

# 测试不同的聚合方式
input = torch.ones(3)
index = torch.tensor([0, 0, 0])
source = torch.tensor([2.0, 4.0, 6.0])

# prod 聚合
output_prod = torch.index_reduce(input, 0, index, source, reduce='prod')
print("prod 聚合:", output_prod)

# amax 聚合
output_max = torch.index_reduce(input, 0, index, source, reduce='amax')
print("amax 聚合:", output_max)

# amin 聚合
output_min = torch.index_reduce(input, 0, index, source, reduce='amin')
print("amin 聚合:", output_min)

输出结果为:

prod 聚合: tensor([48.,  1.,  1.])
amax 聚合: tensor([7., 1., 1.])
amin 聚合: tensor([3., 1., 1.])

实例

import torch

# include_self 参数
input = torch.tensor([1.0, 10.0, 100.0])
index = torch.tensor([0, 0])
source = torch.tensor([5.0, 5.0])

# 包含自身值 (默认)
output1 = torch.index_reduce(input, 0, index, source, reduce='mean', include_self=True)
print("include_self=True:", output1)

# 不包含自身值
output2 = torch.index_reduce(input, 0, index, source, reduce='mean', include_self=False)
print("include_self=False:", output2)

输出结果为:

include_self=True: tensor([ 3.6667, 10.0000, 100.0000])
include_self=False: tensor([ 5., 10., 100.])

实例

import torch

# 2D张量应用
input = torch.zeros(3, 4)
index = torch.tensor([0, 2, 2])
source = torch.randn(3, 4)

# mean 聚合
output = torch.index_reduce(input, dim=0, index=index, source=source, reduce='mean')

print("输入形状:", input.shape)
print("索引:", index)
print("源形状:", source.shape)
print("结果形状:", output.shape)
print("n结果:")
print(output)

注意:torch.index_reduce 不会修改原始输入张量,而是返回一个新的张量。多个索引可以指向同一位置,值会按照指定的聚合方式进行聚合。include_self=False 在某些场景下很有用,可以排除原始值只聚合新添加的值。


Pytorch torch 参考手册 Pytorch torch 参考手册