现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.topk 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.topk 是 PyTorch 中用于返回最大的 k 个元素和索引的函数。

函数定义

torch.topk(input, k, dim, largest, sorted, out)

使用示例

实例

import torch

x = torch.tensor([[3, 1, 2], [6, 4, 5]])

# 返回最大的 2 个元素
values, indices = torch.topk(x, k=2)

print("最大的 2 个值:")
print(values)
print("对应的索引:")
print(indices)

输出结果为:

最大的 2 个值:
tensor([[3, 2],
        [6, 5]])
对应的索引:
tensor([[0, 2],
        [0, 2]])

Pytorch torch 参考手册 Pytorch torch 参考手册