现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.select_scatter 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.select_scatter 是 PyTorch 中用于将源张量散布到指定索引位置的函数。它将 src 的值散布到 input 的指定维度 dim 和索引 index 位置。

函数定义

torch.select_scatter(input, src, dim, index)

参数:

  • input (Tensor): 输入张量,即要修改的张量。
  • src (Tensor): 源张量,要散布到 input 中的值。
  • dim (int): 散布的维度。
  • index (int): 散布的位置索引。

返回值:

  • torch.Tensor: 返回修改后的张量。

使用示例

实例

import torch

# 创建输入张量和源张量
input = torch.zeros(4, 4)
src = torch.ones(4)

# 在第一维度(行)索引为1的位置散布 src
output = torch.select_scatter(input, src, dim=0, index=1)

print("输入张量:")
print(input)
print("n源张量:")
print(src)
print("n在索引1位置散布结果:")
print(output)

输出结果为:

输入张量:
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

源张量:
tensor([1., 1., 1., 1.])

在索引1位置散布结果:
tensor([[0., 0., 0., 0.],
        [1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

实例

import torch

# 创建3D张量
input = torch.zeros(3, 4, 5)
src = torch.ones(5)

# 在第二维度(索引2)散布
output = torch.select_scatter(input, src, dim=1, index=2)

print("输入形状:", input.shape)
print("源形状:", src.shape)
print("结果形状:", output.shape)
print("n结果 (dim=1, index=2):")
print(output[0])

输出结果为:

输入形状: torch.Size([3, 4, 5])
源形状: torch.Size([5])
结果形状: torch.Size([3, 4, 5])

结果 (dim=1, index=2):
tensor([1., 1., 1., 1., 1.])

实例

import torch

# 使用不同的值
input = torch.arange(16).reshape(4, 4).float()
src = torch.tensor([100, 100, 100, 100])

# 在最后一行散布
output = torch.select_scatter(input, src, dim=0, index=-1)

print("原始张量:")
print(input)
print("n散布到索引-1位置:")
print(output)

输出结果为:

原始张量:
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])

散布到索引-1位置:
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [100., 100., 100., 100.]])

注意:torch.select_scatter 不会修改原始输入张量,而是返回一个新的张量。src 的维度必须与要替换的切片的维度相匹配。


Pytorch torch 参考手册 Pytorch torch 参考手册