PyTorch torch.median 函数
torch.median 是 PyTorch 中用于返回张量的中位数的函数。中位数是将数据排序后位于中间位置的值。
函数定义
torch.median(input, dim, keepdim=False)
使用示例
实例
import torch
x = torch.tensor([4, 2, 1, 3, 5])
# 返回所有元素的中位数
print("全局中位数:", torch.median(x))
# 沿 dim=0 中位数
y = torch.tensor([[1, 3, 2], [4, 1, 3]])
print("dim=0 中位数:", torch.median(y, dim=0))
x = torch.tensor([4, 2, 1, 3, 5])
# 返回所有元素的中位数
print("全局中位数:", torch.median(x))
# 沿 dim=0 中位数
y = torch.tensor([[1, 3, 2], [4, 1, 3]])
print("dim=0 中位数:", torch.median(y, dim=0))
输出结果为:
全局中位数: tensor(3) dim=0 中位数: torch.return_types.median(values=tensor([1, 1, 2]), indices=tensor([0, 1, 0]))

Pytorch torch 参考手册