现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.index_copy 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.index_copy 是 PyTorch 中用于将源张量复制到指定索引位置的函数。它沿指定维度 dim,在 index 指定的索引位置复制 source 的值。

torch.index_add 不同的是,index_copy 是覆盖而非累加。

函数定义

torch.index_copy(input, dim, index, source)

参数:

  • input (Tensor): 输入张量。
  • dim (int): 索引的维度。
  • index (Tensor): 一维整数张量,指定要复制到的位置。
  • source (Tensor): 源张量,要复制的值。

返回值:

  • torch.Tensor: 返回修改后的张量。

使用示例

实例

import torch

# 创建输入张量
input = torch.randn(4, 5)

# 创建索引和源
index = torch.tensor([0, 2, 3])
source = torch.randn(3, 5)

# 沿 dim=0 复制
output = torch.index_copy(input, dim=0, index=index, source=source)

print("输入:")
print(input)
print("n源:")
print(source)
print("n复制到位置 [0, 2, 3] 后的结果:")
print(output)

输出结果为:

输入:
tensor([[ 0.3456, -0.1234,  0.5678, -0.2345,  0.8901],
        [-0.5678,  0.1234, -0.6789,  0.2345, -0.1234],
        [ 0.7890, -0.3456,  0.1234, -0.5678,  0.3456],
        [-0.1234,  0.4567, -0.8901,  0.6789, -0.5678]])

源:
tensor([[-1.2345,  0.5678, -1.2345,  0.5678, -1.2345],
        [ 1.5678, -0.6789,  1.5678, -0.6789,  1.5678],
        [-0.8901,  1.2345, -0.8901,  1.2345, -0.8901]])

复制到位置 [0, 2, 3] 后的结果:
tensor([[-1.2345,  0.5678, -1.2345,  0.5678, -1.2345],
        [-0.5678,  0.1234, -0.6789,  0.2345, -0.1234],
        [ 1.5678, -0.6789,  1.5678, -0.6789,  1.5678],
        [-0.8901,  1.2345, -0.8901,  1.2345, -0.8901]])

实例

import torch

# 沿 dim=1 复制
input = torch.zeros(3, 5)
index = torch.tensor([1, 3])
source = torch.tensor([[10, 20, 30, 40, 50],
                        [60, 70, 80, 90, 100]])

output = torch.index_copy(input, dim=1, index=index, source=source)

print("输入:")
print(input)
print("n源:")
print(source)
print("n复制到位置 [1, 3] 后的结果:")
print(output)

输出结果为:

输入:
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

源:
tensor([[ 10.,  20.,  30.,  40.,  50.],
        [ 60.,  70.,  80.,  90., 100.]])

复制到位置 [1, 3] 后的结果:
tensor([[  0.,  10.,   0.,  20.,   0.],
        [  0.,  60.,   0.,  70.,   0.],
        [  0.,   0.,   0.,   0.,   0.]])

实例

import torch

# 构建大型张量
# 假设需要将多个小批次的结果合并到一个大的batch中

# 目标张量
batch_size = 8
feature_dim = 4
output = torch.zeros(batch_size, feature_dim)

# 模拟多个小批次的结果
mini_batches = [
    torch.randn(2, feature_dim),
    torch.randn(3, feature_dim),
    torch.randn(1, feature_dim)
]

# 每个批次要放置的位置
indices = [0, 2, 5]

# 依次复制每个批次
for idx, batch in zip(indices, mini_batches):
    # 创建对应大小的索引
    index = torch.arange(idx, idx + len(batch))
    output = torch.index_copy(output, dim=0, index=index, source=batch)

print("最终输出形状:", output.shape)
print(output)

输出结果为:

最终输出形状: torch.Size([8, 4])
tensor([[ 0.1234, -0.5678,  0.8901, -0.2345],
        [ 0.6789, -0.1234, -0.5678,  0.3456],
        [ 1.2345, -0.8901,  0.1234, -0.6789],
        [-0.3456,  0.5678, -0.1234,  0.8901],
        [ 1.5678, -0.2345,  0.6789, -0.1234],
        [-0.8901,  0.3456,  0.5678, -0.8901],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000]])

注意:torch.index_copy 不会修改原始输入张量,而是返回一个新的张量。index_copy 是覆盖操作,与 torch.index_add 的累加操作不同。


Pytorch torch 参考手册 Pytorch torch 参考手册