现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.select 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.select 是 PyTorch 中用于沿指定维度选择索引对应切片的函数。它返回指定维度上特定索引位置的切片视图。

函数定义

torch.select(input, dim, index)

参数:

  • input (Tensor): 输入张量。
  • dim (int): 要选择的维度。
  • index (int): 要选择的索引。

返回值:

  • torch.Tensor: 返回指定索引位置的切片(维度减1)。

使用示例

实例

import torch

# 创建一个3x4的张量
x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])

# 选择第一维(行)上索引为1的行
y = torch.select(x, dim=0, index=1)

print("原始张量:")
print(x)
print("n选择索引1的行:")
print(y)

输出结果为:

原始张量:
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])

选择索引1的行:
tensor([5, 6, 7, 8])

实例

import torch

# 创建一个3D张量
x = torch.arange(24).reshape(2, 3, 4)

print("原始3D张量:")
print(x)
print("形状:", x.shape)

# 选择第一个维度(batch)中索引为0的元素
y = torch.select(x, dim=0, index=0)
print("n选择 dim=0, index=0:")
print(y)
print("形状:", y.shape)

输出结果为:

原始3D张量:
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
形状: torch.Size([2, 3, 4])

选择 dim=0, index=0:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
形状: torch.Size([3, 4])

实例

import torch

# 使用负索引
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# 选择最后一行的等价方式
y = torch.select(x, dim=0, index=-1)

print("原始张量:")
print(x)
print("n选择最后一行 (index=-1):")
print(y)

输出结果为:

原始张量:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

选择最后一行 (index=-1):
tensor([7, 8, 9])

注意:torch.select 返回的是视图,而不是副本,因此操作是高效的。类似的功能也可以使用切片操作来实现,如 x[index]x[index:index+1]


Pytorch torch 参考手册 Pytorch torch 参考手册