PyTorch torch.select 函数
torch.select 是 PyTorch 中用于沿指定维度选择索引对应切片的函数。它返回指定维度上特定索引位置的切片视图。
函数定义
torch.select(input, dim, index)
参数:
input(Tensor): 输入张量。dim(int): 要选择的维度。index(int): 要选择的索引。
返回值:
torch.Tensor: 返回指定索引位置的切片(维度减1)。
使用示例
实例
import torch
# 创建一个3x4的张量
x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
# 选择第一维(行)上索引为1的行
y = torch.select(x, dim=0, index=1)
print("原始张量:")
print(x)
print("n选择索引1的行:")
print(y)
# 创建一个3x4的张量
x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
# 选择第一维(行)上索引为1的行
y = torch.select(x, dim=0, index=1)
print("原始张量:")
print(x)
print("n选择索引1的行:")
print(y)
输出结果为:
原始张量:
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
选择索引1的行:
tensor([5, 6, 7, 8])
实例
import torch
# 创建一个3D张量
x = torch.arange(24).reshape(2, 3, 4)
print("原始3D张量:")
print(x)
print("形状:", x.shape)
# 选择第一个维度(batch)中索引为0的元素
y = torch.select(x, dim=0, index=0)
print("n选择 dim=0, index=0:")
print(y)
print("形状:", y.shape)
# 创建一个3D张量
x = torch.arange(24).reshape(2, 3, 4)
print("原始3D张量:")
print(x)
print("形状:", x.shape)
# 选择第一个维度(batch)中索引为0的元素
y = torch.select(x, dim=0, index=0)
print("n选择 dim=0, index=0:")
print(y)
print("形状:", y.shape)
输出结果为:
原始3D张量:
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
形状: torch.Size([2, 3, 4])
选择 dim=0, index=0:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
形状: torch.Size([3, 4])
实例
import torch
# 使用负索引
x = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 选择最后一行的等价方式
y = torch.select(x, dim=0, index=-1)
print("原始张量:")
print(x)
print("n选择最后一行 (index=-1):")
print(y)
# 使用负索引
x = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 选择最后一行的等价方式
y = torch.select(x, dim=0, index=-1)
print("原始张量:")
print(x)
print("n选择最后一行 (index=-1):")
print(y)
输出结果为:
原始张量:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
选择最后一行 (index=-1):
tensor([7, 8, 9])
注意:torch.select 返回的是视图,而不是副本,因此操作是高效的。类似的功能也可以使用切片操作来实现,如 x[index] 或 x[index:index+1]。

Pytorch torch 参考手册