PyTorch torch.index_select 函数
torch.index_select 是 PyTorch 中用于沿指定维度选择索引对应元素的函数。
函数定义
torch.index_select(input, dim, index)
使用示例
实例
import torch
x = torch.randn(4, 5)
# 选择第 0 和第 2 行
indices = torch.tensor([0, 2])
result = torch.index_select(x, dim=0, index=indices)
print("结果形状:", result.shape)
x = torch.randn(4, 5)
# 选择第 0 和第 2 行
indices = torch.tensor([0, 2])
result = torch.index_select(x, dim=0, index=indices)
print("结果形状:", result.shape)

Pytorch torch 参考手册