现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.sort 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.sort 是 PyTorch 中用于沿指定维度排序的函数。

函数定义

torch.sort(input, dim, descending, stable, out)

使用示例

实例

import torch

x = torch.tensor([[3, 1, 2], [6, 4, 5]])

# 排序
values, indices = torch.sort(x)

print("排序值:")
print(values)
print("排序索引:")
print(indices)

# 降序排序
values_desc, _ = torch.sort(x, descending=True)
print("降序:")
print(values_desc)

Pytorch torch 参考手册 Pytorch torch 参考手册