PyTorch torch.max 函数
torch.max 是 PyTorch 中用于计算张量最大值的函数。
函数定义
torch.max(input, dim, keepdim, out)
使用示例
实例
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 全局最大值
print("全局最大:", torch.max(x))
# 沿维度最大值
print("dim=0 最大:", torch.max(x, dim=0))
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 全局最大值
print("全局最大:", torch.max(x))
# 沿维度最大值
print("dim=0 最大:", torch.max(x, dim=0))
输出结果为:
全局最大: tensor(6) dim=0 最大: torch.return_types.max(values=tensor([4, 5, 6]), indices=tensor([1, 1, 1]))

Pytorch torch 参考手册