PyTorch torch.slice_scatter 函数
torch.slice_scatter 是 PyTorch 中用于将源张量的值散布到输入张量切片中的函数。它将 src 的值散布到 input 的指定切片位置。
函数定义
torch.slice_scatter(input, src, dim=0, start=None, end=None, step=1)
参数:
input(Tensor): 输入张量,即要修改的张量。src(Tensor): 源张量,要散布到 input 中的值。dim(int, 可选): 散布的维度,默认为 0。start(int, 可选): 起始索引,默认为 None(从0开始)。end(int, 可选): 结束索引,默认为 None(到末尾结束)。step(int, 可选): 步长,默认为 1。
返回值:
torch.Tensor: 返回修改后的张量。
使用示例
实例
import torch
# 创建输入张量和源张量
input = torch.zeros(8, 4)
src = torch.ones(2, 4)
# 将 src 散布到 input 的前两行
output = torch.slice_scatter(input, src, dim=0, end=2)
print("输入张量:")
print(input)
print("n源张量:")
print(src)
print("n散布结果:")
print(output)
# 创建输入张量和源张量
input = torch.zeros(8, 4)
src = torch.ones(2, 4)
# 将 src 散布到 input 的前两行
output = torch.slice_scatter(input, src, dim=0, end=2)
print("输入张量:")
print(input)
print("n源张量:")
print(src)
print("n散布结果:")
print(output)
输出结果为:
输入张量:
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
源张量:
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.]])
散布结果:
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
实例
import torch
# 使用 start 和 end 指定范围
input = torch.zeros(10)
src = torch.tensor([1, 2, 3])
# 将 src 散布到 input 的索引 2-4 位置
output = torch.slice_scatter(input, src, dim=0, start=2, end=5)
print("输入:", input)
print("源:", src)
print("结果:", output)
# 使用 start 和 end 指定范围
input = torch.zeros(10)
src = torch.tensor([1, 2, 3])
# 将 src 散布到 input 的索引 2-4 位置
output = torch.slice_scatter(input, src, dim=0, start=2, end=5)
print("输入:", input)
print("源:", src)
print("结果:", output)
输出结果为:
输入: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) 源: tensor([1., 2., 3.]) 结果: tensor([0., 0., 1., 2., 3., 0., 0., 0., 0., 0.])
实例
import torch
# 使用 step 参数
input = torch.zeros(10)
src = torch.tensor([1, 2])
# 步长为2,从索引0开始到6结束
output = torch.slice_scatter(input, src, dim=0, start=0, end=6, step=2)
print("输入:", input)
print("源:", src)
print("步长为2的结果:", output)
# 使用 step 参数
input = torch.zeros(10)
src = torch.tensor([1, 2])
# 步长为2,从索引0开始到6结束
output = torch.slice_scatter(input, src, dim=0, start=0, end=6, step=2)
print("输入:", input)
print("源:", src)
print("步长为2的结果:", output)
输出结果为:
输入: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) 源: tensor([1., 2.]) 步长为2的结果: tensor([1., 0., 2., 0., 0., 0., 0., 0., 0., 0.])
注意:torch.slice_scatter 不会修改原始输入张量,而是返回一个新的张量。该函数是 torch.slice 的逆操作。

Pytorch torch 参考手册