PyTorch torch.select_scatter 函数
torch.select_scatter 是 PyTorch 中用于将源张量散布到指定索引位置的函数。它将 src 的值散布到 input 的指定维度 dim 和索引 index 位置。
函数定义
torch.select_scatter(input, src, dim, index)
参数:
input(Tensor): 输入张量,即要修改的张量。src(Tensor): 源张量,要散布到 input 中的值。dim(int): 散布的维度。index(int): 散布的位置索引。
返回值:
torch.Tensor: 返回修改后的张量。
使用示例
实例
import torch
# 创建输入张量和源张量
input = torch.zeros(4, 4)
src = torch.ones(4)
# 在第一维度(行)索引为1的位置散布 src
output = torch.select_scatter(input, src, dim=0, index=1)
print("输入张量:")
print(input)
print("n源张量:")
print(src)
print("n在索引1位置散布结果:")
print(output)
# 创建输入张量和源张量
input = torch.zeros(4, 4)
src = torch.ones(4)
# 在第一维度(行)索引为1的位置散布 src
output = torch.select_scatter(input, src, dim=0, index=1)
print("输入张量:")
print(input)
print("n源张量:")
print(src)
print("n在索引1位置散布结果:")
print(output)
输出结果为:
输入张量:
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
源张量:
tensor([1., 1., 1., 1.])
在索引1位置散布结果:
tensor([[0., 0., 0., 0.],
[1., 1., 1., 1.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
实例
import torch
# 创建3D张量
input = torch.zeros(3, 4, 5)
src = torch.ones(5)
# 在第二维度(索引2)散布
output = torch.select_scatter(input, src, dim=1, index=2)
print("输入形状:", input.shape)
print("源形状:", src.shape)
print("结果形状:", output.shape)
print("n结果 (dim=1, index=2):")
print(output[0])
# 创建3D张量
input = torch.zeros(3, 4, 5)
src = torch.ones(5)
# 在第二维度(索引2)散布
output = torch.select_scatter(input, src, dim=1, index=2)
print("输入形状:", input.shape)
print("源形状:", src.shape)
print("结果形状:", output.shape)
print("n结果 (dim=1, index=2):")
print(output[0])
输出结果为:
输入形状: torch.Size([3, 4, 5]) 源形状: torch.Size([5]) 结果形状: torch.Size([3, 4, 5]) 结果 (dim=1, index=2): tensor([1., 1., 1., 1., 1.])
实例
import torch
# 使用不同的值
input = torch.arange(16).reshape(4, 4).float()
src = torch.tensor([100, 100, 100, 100])
# 在最后一行散布
output = torch.select_scatter(input, src, dim=0, index=-1)
print("原始张量:")
print(input)
print("n散布到索引-1位置:")
print(output)
# 使用不同的值
input = torch.arange(16).reshape(4, 4).float()
src = torch.tensor([100, 100, 100, 100])
# 在最后一行散布
output = torch.select_scatter(input, src, dim=0, index=-1)
print("原始张量:")
print(input)
print("n散布到索引-1位置:")
print(output)
输出结果为:
原始张量:
tensor([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.]])
散布到索引-1位置:
tensor([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[100., 100., 100., 100.]])
注意:torch.select_scatter 不会修改原始输入张量,而是返回一个新的张量。src 的维度必须与要替换的切片的维度相匹配。

Pytorch torch 参考手册