PyTorch torch.scatter 函数
torch.scatter 是 PyTorch 中用于将值散布到指定位置的函数。
函数定义
torch.scatter(input, dim, index, src, reduce)
使用示例
实例
import torch
x = torch.zeros(3, 4)
# 将值散布到指定位置
index = torch.tensor([[0], [1], [2]])
src = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
result = torch.scatter(x, dim=0, index=index, src=src)
print(result)
x = torch.zeros(3, 4)
# 将值散布到指定位置
index = torch.tensor([[0], [1], [2]])
src = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
result = torch.scatter(x, dim=0, index=index, src=src)
print(result)

Pytorch torch 参考手册