PyTorch torch.gather 函数
torch.gather 是 PyTorch 中用于沿指定维度收集指定索引元素的函数。
函数定义
torch.gather(input, dim, index, sparse_grad)
使用示例
实例
import torch
x = torch.tensor([[1, 2], [3, 4], [5, 6]])
# 沿 dim=1 收集
index = torch.tensor([[0], [1], [0]])
result = torch.gather(x, dim=1, index=index)
print(result)
x = torch.tensor([[1, 2], [3, 4], [5, 6]])
# 沿 dim=1 收集
index = torch.tensor([[0], [1], [0]])
result = torch.gather(x, dim=1, index=index)
print(result)
输出结果为:
tensor([[1],
[4],
[5]])

Pytorch torch 参考手册