PyTorch torch.index_copy 函数
torch.index_copy 是 PyTorch 中用于将源张量复制到指定索引位置的函数。它沿指定维度 dim,在 index 指定的索引位置复制 source 的值。
与 torch.index_add 不同的是,index_copy 是覆盖而非累加。
函数定义
torch.index_copy(input, dim, index, source)
参数:
input(Tensor): 输入张量。dim(int): 索引的维度。index(Tensor): 一维整数张量,指定要复制到的位置。source(Tensor): 源张量,要复制的值。
返回值:
torch.Tensor: 返回修改后的张量。
使用示例
实例
import torch
# 创建输入张量
input = torch.randn(4, 5)
# 创建索引和源
index = torch.tensor([0, 2, 3])
source = torch.randn(3, 5)
# 沿 dim=0 复制
output = torch.index_copy(input, dim=0, index=index, source=source)
print("输入:")
print(input)
print("n源:")
print(source)
print("n复制到位置 [0, 2, 3] 后的结果:")
print(output)
# 创建输入张量
input = torch.randn(4, 5)
# 创建索引和源
index = torch.tensor([0, 2, 3])
source = torch.randn(3, 5)
# 沿 dim=0 复制
output = torch.index_copy(input, dim=0, index=index, source=source)
print("输入:")
print(input)
print("n源:")
print(source)
print("n复制到位置 [0, 2, 3] 后的结果:")
print(output)
输出结果为:
输入:
tensor([[ 0.3456, -0.1234, 0.5678, -0.2345, 0.8901],
[-0.5678, 0.1234, -0.6789, 0.2345, -0.1234],
[ 0.7890, -0.3456, 0.1234, -0.5678, 0.3456],
[-0.1234, 0.4567, -0.8901, 0.6789, -0.5678]])
源:
tensor([[-1.2345, 0.5678, -1.2345, 0.5678, -1.2345],
[ 1.5678, -0.6789, 1.5678, -0.6789, 1.5678],
[-0.8901, 1.2345, -0.8901, 1.2345, -0.8901]])
复制到位置 [0, 2, 3] 后的结果:
tensor([[-1.2345, 0.5678, -1.2345, 0.5678, -1.2345],
[-0.5678, 0.1234, -0.6789, 0.2345, -0.1234],
[ 1.5678, -0.6789, 1.5678, -0.6789, 1.5678],
[-0.8901, 1.2345, -0.8901, 1.2345, -0.8901]])
实例
import torch
# 沿 dim=1 复制
input = torch.zeros(3, 5)
index = torch.tensor([1, 3])
source = torch.tensor([[10, 20, 30, 40, 50],
[60, 70, 80, 90, 100]])
output = torch.index_copy(input, dim=1, index=index, source=source)
print("输入:")
print(input)
print("n源:")
print(source)
print("n复制到位置 [1, 3] 后的结果:")
print(output)
# 沿 dim=1 复制
input = torch.zeros(3, 5)
index = torch.tensor([1, 3])
source = torch.tensor([[10, 20, 30, 40, 50],
[60, 70, 80, 90, 100]])
output = torch.index_copy(input, dim=1, index=index, source=source)
print("输入:")
print(input)
print("n源:")
print(source)
print("n复制到位置 [1, 3] 后的结果:")
print(output)
输出结果为:
输入:
tensor([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]])
源:
tensor([[ 10., 20., 30., 40., 50.],
[ 60., 70., 80., 90., 100.]])
复制到位置 [1, 3] 后的结果:
tensor([[ 0., 10., 0., 20., 0.],
[ 0., 60., 0., 70., 0.],
[ 0., 0., 0., 0., 0.]])
实例
import torch
# 构建大型张量
# 假设需要将多个小批次的结果合并到一个大的batch中
# 目标张量
batch_size = 8
feature_dim = 4
output = torch.zeros(batch_size, feature_dim)
# 模拟多个小批次的结果
mini_batches = [
torch.randn(2, feature_dim),
torch.randn(3, feature_dim),
torch.randn(1, feature_dim)
]
# 每个批次要放置的位置
indices = [0, 2, 5]
# 依次复制每个批次
for idx, batch in zip(indices, mini_batches):
# 创建对应大小的索引
index = torch.arange(idx, idx + len(batch))
output = torch.index_copy(output, dim=0, index=index, source=batch)
print("最终输出形状:", output.shape)
print(output)
# 构建大型张量
# 假设需要将多个小批次的结果合并到一个大的batch中
# 目标张量
batch_size = 8
feature_dim = 4
output = torch.zeros(batch_size, feature_dim)
# 模拟多个小批次的结果
mini_batches = [
torch.randn(2, feature_dim),
torch.randn(3, feature_dim),
torch.randn(1, feature_dim)
]
# 每个批次要放置的位置
indices = [0, 2, 5]
# 依次复制每个批次
for idx, batch in zip(indices, mini_batches):
# 创建对应大小的索引
index = torch.arange(idx, idx + len(batch))
output = torch.index_copy(output, dim=0, index=index, source=batch)
print("最终输出形状:", output.shape)
print(output)
输出结果为:
最终输出形状: torch.Size([8, 4])
tensor([[ 0.1234, -0.5678, 0.8901, -0.2345],
[ 0.6789, -0.1234, -0.5678, 0.3456],
[ 1.2345, -0.8901, 0.1234, -0.6789],
[-0.3456, 0.5678, -0.1234, 0.8901],
[ 1.5678, -0.2345, 0.6789, -0.1234],
[-0.8901, 0.3456, 0.5678, -0.8901],
[ 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000]])
注意:torch.index_copy 不会修改原始输入张量,而是返回一个新的张量。index_copy 是覆盖操作,与 torch.index_add 的累加操作不同。

Pytorch torch 参考手册