PyTorch torch.topk 函数
torch.topk 是 PyTorch 中用于返回最大的 k 个元素和索引的函数。
函数定义
torch.topk(input, k, dim, largest, sorted, out)
使用示例
实例
import torch
x = torch.tensor([[3, 1, 2], [6, 4, 5]])
# 返回最大的 2 个元素
values, indices = torch.topk(x, k=2)
print("最大的 2 个值:")
print(values)
print("对应的索引:")
print(indices)
x = torch.tensor([[3, 1, 2], [6, 4, 5]])
# 返回最大的 2 个元素
values, indices = torch.topk(x, k=2)
print("最大的 2 个值:")
print(values)
print("对应的索引:")
print(indices)
输出结果为:
最大的 2 个值:
tensor([[3, 2],
[6, 5]])
对应的索引:
tensor([[0, 2],
[0, 2]])

Pytorch torch 参考手册