现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.bucketize 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.bucketize 是 PyTorch 中用于桶排序索引的函数。它将输入值映射到对应的桶索引,基于边界数组进行二分查找。常用于将连续值离散化或分箱操作。

函数定义

torch.bucketize(input, boundaries, right=False, out_int32=False, **kwargs)

使用示例

实例

import torch

# 基础用法:桶排序索引
boundaries = torch.tensor([0, 1, 2, 3, 4])
values = torch.tensor([0.5, 1.5, 2.5, 3.5, 0.2, 2.0])

result = torch.bucketize(values, boundaries)
print("边界:", boundaries)
print("值:", values)
print("桶索引:", result)
# 输出: tensor([1, 2, 3, 4, 0, 2])

# right=True 表示边界为右闭区间
result_right = torch.bucketize(values, boundaries, right=True)
print("right=True 桶索引:", result_right)
# 输出: tensor([0, 1, 2, 3, 0, 1])

# 多维输入
values = torch.tensor([[0.5, 1.5], [2.5, 3.5]])
result = torch.bucketize(values, boundaries)
print("多维输入结果:", result)

Pytorch torch 参考手册 Pytorch torch 参考手册