PyTorch torch.nanmean 函数
torch.nanmean 是 PyTorch 中用于计算忽略 NaN 值的平均值的函数。
函数定义
torch.nanmean(input, dim, keepdim=False, out=None)
使用示例
实例
import torch
# 创建包含 NaN 的张量
x = torch.tensor([1.0, 2.0, float('nan'), 4.0, 5.0])
# 计算非 NaN 值的平均值
mean = torch.nanmean(x)
print("非 NaN 均值:", mean)
# 创建包含 NaN 的张量
x = torch.tensor([1.0, 2.0, float('nan'), 4.0, 5.0])
# 计算非 NaN 值的平均值
mean = torch.nanmean(x)
print("非 NaN 均值:", mean)
输出结果为:
非 NaN 均值: tensor(3.)

Pytorch torch 参考手册