PyTorch torch.diagonal_scatter 函数
torch.diagonal_scatter 是 PyTorch 中用于将值散布到张量对角线位置的函数。它将 src 的值散布到 input 的指定对角线上。
函数定义
torch.diagonal_scatter(input, src, offset=0, dim1=0, dim2=1)
参数:
input(Tensor): 输入张量,即要修改的张量。src(Tensor): 源张量,要散布到对角线位置的值。offset(int, 可选): 对角线偏移量。正值表示上对角线,负值表示下对角线,0表示主对角线。dim1(int, 可选): 第一个维度,默认为 0。dim2(int, 可选): 第二个维度,默认为 1。
返回值:
torch.Tensor: 返回修改后的张量。
使用示例
实例
import torch
# 创建输入张量
input = torch.zeros(4, 4)
src = torch.tensor([1, 2, 3, 4])
# 将值散布到主对角线
output = torch.diagonal_scatter(input, src)
print("输入张量:")
print(input)
print("n源张量:")
print(src)
print("n散布到主对角线:")
print(output)
# 创建输入张量
input = torch.zeros(4, 4)
src = torch.tensor([1, 2, 3, 4])
# 将值散布到主对角线
output = torch.diagonal_scatter(input, src)
print("输入张量:")
print(input)
print("n源张量:")
print(src)
print("n散布到主对角线:")
print(output)
输出结果为:
输入张量:
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
源张量:
tensor([1, 2, 3, 4])
散布到主对角线:
tensor([[1., 0., 0., 0.],
[0., 2., 0., 0.],
[0., 0., 3., 0.],
[0., 0., 0., 4.]])
实例
import torch
# 创建输入张量
input = torch.zeros(4, 4)
src = torch.tensor([1, 2, 3])
# 将值散布到上对角线 (offset=1)
output = torch.diagonal_scatter(input, src, offset=1)
print("散布到上对角线 (offset=1):")
print(output)
# 将值散布到下对角线 (offset=-1)
output2 = torch.diagonal_scatter(input, src, offset=-1)
print("n散布到下对角线 (offset=-1):")
print(output2)
# 创建输入张量
input = torch.zeros(4, 4)
src = torch.tensor([1, 2, 3])
# 将值散布到上对角线 (offset=1)
output = torch.diagonal_scatter(input, src, offset=1)
print("散布到上对角线 (offset=1):")
print(output)
# 将值散布到下对角线 (offset=-1)
output2 = torch.diagonal_scatter(input, src, offset=-1)
print("n散布到下对角线 (offset=-1):")
print(output2)
输出结果为:
散布到上对角线 (offset=1):
tensor([[0., 1., 0., 0.],
[0., 0., 2., 0.],
[0., 0., 0., 3.],
[0., 0., 0., 0.]])
散布到下对角线 (offset=-1):
tensor([[0., 0., 0., 0.],
[1., 0., 0., 0.],
[0., 2., 0., 0.],
[0., 0., 3., 0.]])
实例
import torch
# 3D张量中使用对角线
input = torch.zeros(3, 4, 4)
src = torch.tensor([10, 20, 30])
# 在指定的两个维度上散布
output = torch.diagonal_scatter(input, src, dim1=1, dim2=2)
print("输入形状:", input.shape)
print("源形状:", src.shape)
print("结果形状:", output.shape)
# 查看第一个batch
print("n第一个batch的结果:")
print(output[0])
# 3D张量中使用对角线
input = torch.zeros(3, 4, 4)
src = torch.tensor([10, 20, 30])
# 在指定的两个维度上散布
output = torch.diagonal_scatter(input, src, dim1=1, dim2=2)
print("输入形状:", input.shape)
print("源形状:", src.shape)
print("结果形状:", output.shape)
# 查看第一个batch
print("n第一个batch的结果:")
print(output[0])
输出结果为:
输入形状: torch.Size([3, 4, 4])
源形状: torch.Size([3])
结果形状: torch.Size([3, 4, 4])
第一个batch的结果:
tensor([[10., 0., 0., 0.],
[0., 20., 0., 0.],
[0., 0., 30., 0.],
[0., 0., 0., 0.]])
注意:torch.diagonal_scatter 不会修改原始输入张量,而是返回一个新的张量。src 的大小必须与对角线元素的数量相匹配。

Pytorch torch 参考手册