PyTorch torch.argsort 函数
torch.argsort 是 PyTorch 中用于返回排序后索引的函数。它返回将张量排序后各元素原位置的索引。
函数定义
torch.argsort(input, dim=-1, descending=False, stable=True)
使用示例
实例
import torch
# 创建张量
x = torch.tensor([3, 1, 2])
# 返回排序后的索引
indices = torch.argsort(x)
print(indices)
# 创建张量
x = torch.tensor([3, 1, 2])
# 返回排序后的索引
indices = torch.argsort(x)
print(indices)
输出结果为:
tensor([1, 2, 0])

Pytorch torch 参考手册