PyTorch torch.sort 函数
torch.sort 是 PyTorch 中用于沿指定维度排序的函数。
函数定义
torch.sort(input, dim, descending, stable, out)
使用示例
实例
import torch
x = torch.tensor([[3, 1, 2], [6, 4, 5]])
# 排序
values, indices = torch.sort(x)
print("排序值:")
print(values)
print("排序索引:")
print(indices)
# 降序排序
values_desc, _ = torch.sort(x, descending=True)
print("降序:")
print(values_desc)
x = torch.tensor([[3, 1, 2], [6, 4, 5]])
# 排序
values, indices = torch.sort(x)
print("排序值:")
print(values)
print("排序索引:")
print(indices)
# 降序排序
values_desc, _ = torch.sort(x, descending=True)
print("降序:")
print(values_desc)

Pytorch torch 参考手册