现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.slice_scatter 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.slice_scatter 是 PyTorch 中用于将源张量的值散布到输入张量切片中的函数。它将 src 的值散布到 input 的指定切片位置。

函数定义

torch.slice_scatter(input, src, dim=0, start=None, end=None, step=1)

参数:

  • input (Tensor): 输入张量,即要修改的张量。
  • src (Tensor): 源张量,要散布到 input 中的值。
  • dim (int, 可选): 散布的维度,默认为 0。
  • start (int, 可选): 起始索引,默认为 None(从0开始)。
  • end (int, 可选): 结束索引,默认为 None(到末尾结束)。
  • step (int, 可选): 步长,默认为 1。

返回值:

  • torch.Tensor: 返回修改后的张量。

使用示例

实例

import torch

# 创建输入张量和源张量
input = torch.zeros(8, 4)
src = torch.ones(2, 4)

# 将 src 散布到 input 的前两行
output = torch.slice_scatter(input, src, dim=0, end=2)

print("输入张量:")
print(input)
print("n源张量:")
print(src)
print("n散布结果:")
print(output)

输出结果为:

输入张量:
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

源张量:
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.]])

散布结果:
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

实例

import torch

# 使用 start 和 end 指定范围
input = torch.zeros(10)
src = torch.tensor([1, 2, 3])

# 将 src 散布到 input 的索引 2-4 位置
output = torch.slice_scatter(input, src, dim=0, start=2, end=5)

print("输入:", input)
print("源:", src)
print("结果:", output)

输出结果为:

输入: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
源: tensor([1., 2., 3.])
结果: tensor([0., 0., 1., 2., 3., 0., 0., 0., 0., 0.])

实例

import torch

# 使用 step 参数
input = torch.zeros(10)
src = torch.tensor([1, 2])

# 步长为2,从索引0开始到6结束
output = torch.slice_scatter(input, src, dim=0, start=0, end=6, step=2)

print("输入:", input)
print("源:", src)
print("步长为2的结果:", output)

输出结果为:

输入: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
源: tensor([1., 2.])
步长为2的结果: tensor([1., 0., 2., 0., 0., 0., 0., 0., 0., 0.])

注意:torch.slice_scatter 不会修改原始输入张量,而是返回一个新的张量。该函数是 torch.slice 的逆操作。


Pytorch torch 参考手册 Pytorch torch 参考手册