PyTorch torch.bucketize 函数
torch.bucketize 是 PyTorch 中用于桶排序索引的函数。它将输入值映射到对应的桶索引,基于边界数组进行二分查找。常用于将连续值离散化或分箱操作。
函数定义
torch.bucketize(input, boundaries, right=False, out_int32=False, **kwargs)
使用示例
实例
import torch
# 基础用法:桶排序索引
boundaries = torch.tensor([0, 1, 2, 3, 4])
values = torch.tensor([0.5, 1.5, 2.5, 3.5, 0.2, 2.0])
result = torch.bucketize(values, boundaries)
print("边界:", boundaries)
print("值:", values)
print("桶索引:", result)
# 输出: tensor([1, 2, 3, 4, 0, 2])
# right=True 表示边界为右闭区间
result_right = torch.bucketize(values, boundaries, right=True)
print("right=True 桶索引:", result_right)
# 输出: tensor([0, 1, 2, 3, 0, 1])
# 多维输入
values = torch.tensor([[0.5, 1.5], [2.5, 3.5]])
result = torch.bucketize(values, boundaries)
print("多维输入结果:", result)
# 基础用法:桶排序索引
boundaries = torch.tensor([0, 1, 2, 3, 4])
values = torch.tensor([0.5, 1.5, 2.5, 3.5, 0.2, 2.0])
result = torch.bucketize(values, boundaries)
print("边界:", boundaries)
print("值:", values)
print("桶索引:", result)
# 输出: tensor([1, 2, 3, 4, 0, 2])
# right=True 表示边界为右闭区间
result_right = torch.bucketize(values, boundaries, right=True)
print("right=True 桶索引:", result_right)
# 输出: tensor([0, 1, 2, 3, 0, 1])
# 多维输入
values = torch.tensor([[0.5, 1.5], [2.5, 3.5]])
result = torch.bucketize(values, boundaries)
print("多维输入结果:", result)

Pytorch torch 参考手册