PyTorch torch.count_nonzero 函数
torch.count_nonzero 是 PyTorch 中用于返回张量中非零元素数量的函数。
函数定义
torch.count_nonzero(input, dim)
使用示例
实例
import torch
x = torch.tensor([1, 0, 2, 0, 3, 0, 0, 4])
# 返回非零元素的数量
print("非零元素数量:", torch.count_nonzero(x))
# 沿 dim=0 非零元素数量
y = torch.tensor([[1, 0, 2], [0, 3, 0], [4, 0, 5]])
print("非零元素数量:", torch.count_nonzero(y))
print("dim=0 非零元素数量:", torch.count_nonzero(y, dim=0))
print("dim=1 非零元素数量:", torch.count_nonzero(y, dim=1))
x = torch.tensor([1, 0, 2, 0, 3, 0, 0, 4])
# 返回非零元素的数量
print("非零元素数量:", torch.count_nonzero(x))
# 沿 dim=0 非零元素数量
y = torch.tensor([[1, 0, 2], [0, 3, 0], [4, 0, 5]])
print("非零元素数量:", torch.count_nonzero(y))
print("dim=0 非零元素数量:", torch.count_nonzero(y, dim=0))
print("dim=1 非零元素数量:", torch.count_nonzero(y, dim=1))
输出结果为:
非零元素数量: tensor(4) 非零元素数量: tensor(5) dim=0 非零元素数量: tensor([2, 1, 2]) dim=1 非零元素数量: tensor([2, 1, 2])

Pytorch torch 参考手册