现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.searchsorted 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.searchsorted 是 PyTorch 中用于在排序后的张量中搜索元素应插入位置的函数。返回的值是元素应该插入的索引位置。

函数定义

torch.searchsorted(sorted_sequence, values, side='left', out_int32=False, right=False)

参数说明:

  • sorted_sequence: 已排序的一维或多维张量
  • values: 要搜索的值
  • side: 'left' 或 'right',决定返回左还是右插入位置
  • out_int32: 是否返回 int32 类型
  • right: 已废弃,使用 side 代替

使用示例

实例

import torch

# 创建已排序的序列
sorted_seq = torch.tensor([1, 3, 5, 7, 9])

# 搜索值的位置
values = torch.tensor([3, 6, 8])
y = torch.searchsorted(sorted_seq, values)
print(y)

输出结果为:

tensor([1, 3, 3])
</p>
<div class="example">
<h2 class="example">实例</h2>
<div class="example_code">
import torch

# 创建已排序的序列
sorted_seq = torch.tensor([1, 3, 5, 7, 9])

# 使用 side='right' 搜索
values = torch.tensor([3, 6, 8])
y = torch.searchsorted(sorted_seq, values, side='right')
print(y)
</div> </div> <p>输出结果为:</p> <pre> tensor([2, 3, 4])

实例

import torch

# 用于多维数组
sorted_seq = torch.tensor([[1, 3, 5], [2, 4, 6]])
values = torch.tensor([[1.5], [3.5]])
y = torch.searchsorted(sorted_seq, values)
print(y)

输出结果为:

tensor([[1],
        [1]])

Pytorch torch 参考手册 Pytorch torch 参考手册