PyTorch torch.all 函数
torch.all 是 PyTorch 中用于判断是否所有元素都为 True 的函数。
函数定义
torch.all(input, dim, keepdim, out)
使用示例
实例
import torch
x = torch.tensor([[True, True, True], [True, False, True]])
# 全局判断
print("全部为 True:", torch.all(x))
# 沿维度判断
print("dim=0 全部为 True:", torch.all(x, dim=0))
x = torch.tensor([[True, True, True], [True, False, True]])
# 全局判断
print("全部为 True:", torch.all(x))
# 沿维度判断
print("dim=0 全部为 True:", torch.all(x, dim=0))
输出结果为:
全部为 True: tensor(False) dim=0 全部为 True: tensor([True, False, True])

Pytorch torch 参考手册