PyTorch torch.mean 函数
torch.mean 是 PyTorch 中用于计算张量均值的函数。
函数定义
torch.mean(input, dim, keepdim, dtype, out)
使用示例
实例
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 计算所有元素均值
print("全局均值:", torch.mean(x))
# 沿 dim=0 计算均值
print("dim=0 均值:", torch.mean(x, dim=0))
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 计算所有元素均值
print("全局均值:", torch.mean(x))
# 沿 dim=0 计算均值
print("dim=0 均值:", torch.mean(x, dim=0))
输出结果为:
全局均值: tensor(3.5000) dim=0 均值: tensor([2.5000, 3.5000, 4.5000])

Pytorch torch 参考手册