现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.diagonal_scatter 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.diagonal_scatter 是 PyTorch 中用于将值散布到张量对角线位置的函数。它将 src 的值散布到 input 的指定对角线上。

函数定义

torch.diagonal_scatter(input, src, offset=0, dim1=0, dim2=1)

参数:

  • input (Tensor): 输入张量,即要修改的张量。
  • src (Tensor): 源张量,要散布到对角线位置的值。
  • offset (int, 可选): 对角线偏移量。正值表示上对角线,负值表示下对角线,0表示主对角线。
  • dim1 (int, 可选): 第一个维度,默认为 0。
  • dim2 (int, 可选): 第二个维度,默认为 1。

返回值:

  • torch.Tensor: 返回修改后的张量。

使用示例

实例

import torch

# 创建输入张量
input = torch.zeros(4, 4)
src = torch.tensor([1, 2, 3, 4])

# 将值散布到主对角线
output = torch.diagonal_scatter(input, src)

print("输入张量:")
print(input)
print("n源张量:")
print(src)
print("n散布到主对角线:")
print(output)

输出结果为:

输入张量:
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

源张量:
tensor([1, 2, 3, 4])

散布到主对角线:
tensor([[1., 0., 0., 0.],
        [0., 2., 0., 0.],
        [0., 0., 3., 0.],
        [0., 0., 0., 4.]])

实例

import torch

# 创建输入张量
input = torch.zeros(4, 4)
src = torch.tensor([1, 2, 3])

# 将值散布到上对角线 (offset=1)
output = torch.diagonal_scatter(input, src, offset=1)

print("散布到上对角线 (offset=1):")
print(output)

# 将值散布到下对角线 (offset=-1)
output2 = torch.diagonal_scatter(input, src, offset=-1)

print("n散布到下对角线 (offset=-1):")
print(output2)

输出结果为:

散布到上对角线 (offset=1):
tensor([[0., 1., 0., 0.],
        [0., 0., 2., 0.],
        [0., 0., 0., 3.],
        [0., 0., 0., 0.]])

散布到下对角线 (offset=-1):
tensor([[0., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 2., 0., 0.],
        [0., 0., 3., 0.]])

实例

import torch

# 3D张量中使用对角线
input = torch.zeros(3, 4, 4)
src = torch.tensor([10, 20, 30])

# 在指定的两个维度上散布
output = torch.diagonal_scatter(input, src, dim1=1, dim2=2)

print("输入形状:", input.shape)
print("源形状:", src.shape)
print("结果形状:", output.shape)

# 查看第一个batch
print("n第一个batch的结果:")
print(output[0])

输出结果为:

输入形状: torch.Size([3, 4, 4])
源形状: torch.Size([3])
结果形状: torch.Size([3, 4, 4])

第一个batch的结果:
tensor([[10., 0., 0., 0.],
        [0., 20., 0., 0.],
        [0., 0., 30., 0.],
        [0., 0., 0., 0.]])

注意:torch.diagonal_scatter 不会修改原始输入张量,而是返回一个新的张量。src 的大小必须与对角线元素的数量相匹配。


Pytorch torch 参考手册 Pytorch torch 参考手册