PyTorch torch.searchsorted 函数
torch.searchsorted 是 PyTorch 中用于在排序后的张量中搜索元素应插入位置的函数。返回的值是元素应该插入的索引位置。
函数定义
torch.searchsorted(sorted_sequence, values, side='left', out_int32=False, right=False)
参数说明:
sorted_sequence: 已排序的一维或多维张量values: 要搜索的值side: 'left' 或 'right',决定返回左还是右插入位置out_int32: 是否返回 int32 类型right: 已废弃,使用 side 代替
使用示例
实例
import torch
# 创建已排序的序列
sorted_seq = torch.tensor([1, 3, 5, 7, 9])
# 搜索值的位置
values = torch.tensor([3, 6, 8])
y = torch.searchsorted(sorted_seq, values)
print(y)
# 创建已排序的序列
sorted_seq = torch.tensor([1, 3, 5, 7, 9])
# 搜索值的位置
values = torch.tensor([3, 6, 8])
y = torch.searchsorted(sorted_seq, values)
print(y)
输出结果为:
tensor([1, 3, 3]) </p> <div class="example"> <h2 class="example">实例</h2> <div class="example_code"> import torch
# 创建已排序的序列
sorted_seq = torch.tensor([1, 3, 5, 7, 9])
# 使用 side='right' 搜索
values = torch.tensor([3, 6, 8])
y = torch.searchsorted(sorted_seq, values, side='right')
print(y)
</div> </div> <p>输出结果为:</p> <pre> tensor([2, 3, 4])
实例
import torch
# 用于多维数组
sorted_seq = torch.tensor([[1, 3, 5], [2, 4, 6]])
values = torch.tensor([[1.5], [3.5]])
y = torch.searchsorted(sorted_seq, values)
print(y)
# 用于多维数组
sorted_seq = torch.tensor([[1, 3, 5], [2, 4, 6]])
values = torch.tensor([[1.5], [3.5]])
y = torch.searchsorted(sorted_seq, values)
print(y)
输出结果为:
tensor([[1],
[1]])

Pytorch torch 参考手册