PyTorch torch.bincount 函数
torch.bincount 是 PyTorch 中用于计算每个非负整数值出现次数的函数。它返回一个新的 1 维张量,其中第 i 个元素表示值 i 在输入中出现的次数。常用于直方图计算和分组统计。
函数定义
torch.bincount(input, weights=None, minlength=0)
使用示例
实例
import torch
# 基础用法:统计每个值出现的次数
x = torch.tensor([0, 1, 1, 2, 2, 2, 3, 3, 4])
counts = torch.bincount(x)
print("输入:", x)
print("计数结果:", counts)
# 输出: tensor([1, 2, 3, 2, 1])
# 使用 weights 参数
x = torch.tensor([0, 1, 1, 2, 2, 2])
weights = torch.tensor([1.0, 2.0, 3.0, 1.0, 2.0, 3.0])
weighted_counts = torch.bincount(x, weights=weights)
print("加权计数:", weighted_counts)
# 输出: tensor([1., 5., 6.])
# 设置最小长度
x = torch.tensor([0])
counts = torch.bincount(x, minlength=5)
print("最小长度5:", counts)
# 输出: tensor([1, 0, 0, 0, 0])
# 基础用法:统计每个值出现的次数
x = torch.tensor([0, 1, 1, 2, 2, 2, 3, 3, 4])
counts = torch.bincount(x)
print("输入:", x)
print("计数结果:", counts)
# 输出: tensor([1, 2, 3, 2, 1])
# 使用 weights 参数
x = torch.tensor([0, 1, 1, 2, 2, 2])
weights = torch.tensor([1.0, 2.0, 3.0, 1.0, 2.0, 3.0])
weighted_counts = torch.bincount(x, weights=weights)
print("加权计数:", weighted_counts)
# 输出: tensor([1., 5., 6.])
# 设置最小长度
x = torch.tensor([0])
counts = torch.bincount(x, minlength=5)
print("最小长度5:", counts)
# 输出: tensor([1, 0, 0, 0, 0])

Pytorch torch 参考手册