PyTorch torch.take 函数
torch.take 是 PyTorch 中用于获取给定索引位置元素的函数。它将输入张量视为一维数组,然后返回指定索引位置的元素。
函数定义
torch.take(input, index)
参数:
input(Tensor): 输入张量。index(Tensor): 整数索引张量,指定要获取的元素位置。
返回值:
torch.Tensor: 返回指定索引位置的元素组成的新张量。
使用示例
实例
import torch
# 创建一个2D张量
x = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
print("原始张量:")
print(x)
# 将2D张量视为一维数组,索引0-8
# 取索引0, 4, 8位置的元素
index = torch.tensor([0, 4, 8])
y = torch.take(x, index)
print("n索引:", index)
print("取出的元素:", y)
# 创建一个2D张量
x = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
print("原始张量:")
print(x)
# 将2D张量视为一维数组,索引0-8
# 取索引0, 4, 8位置的元素
index = torch.tensor([0, 4, 8])
y = torch.take(x, index)
print("n索引:", index)
print("取出的元素:", y)
输出结果为:
原始张量:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
索引: tensor([0, 4, 8])
取出的元素: tensor([1, 5, 9])
实例
import torch
# 3D张量
x = torch.arange(24).reshape(2, 3, 4)
print("原始3D张量:")
print(x)
print("形状:", x.shape)
# 取多个索引的元素
index = torch.tensor([0, 1, 2, 10, 20, 23])
y = torch.take(x, index)
print("n索引:", index)
print("取出的元素:", y)
# 3D张量
x = torch.arange(24).reshape(2, 3, 4)
print("原始3D张量:")
print(x)
print("形状:", x.shape)
# 取多个索引的元素
index = torch.tensor([0, 1, 2, 10, 20, 23])
y = torch.take(x, index)
print("n索引:", index)
print("取出的元素:", y)
输出结果为:
原始3D张量:
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
形状: torch.Size([2, 3, 4])
索引: tensor([ 0, 1, 2, 10, 20, 23])
取出的元素: tensor([ 0, 1, 2, 10, 20, 23])
实例
import torch
# 使用负索引
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
# 负索引从末尾开始计算
index = torch.tensor([0, -1]) # 第一个和最后一个元素
y = torch.take(x, index)
print("原始:", x)
print("索引 [0, -1]:", y)
# 使用负索引
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
# 负索引从末尾开始计算
index = torch.tensor([0, -1]) # 第一个和最后一个元素
y = torch.take(x, index)
print("原始:", x)
print("索引 [0, -1]:", y)
输出结果为:
原始: tensor([[1, 2, 3],
[4, 5, 6]])
索引 [0, -1]: tensor([1, 6])
实例
import torch
# 随机选择元素
x = torch.randn(10, 10)
# 随机生成10个索引
index = torch.randint(0, 100, (10,))
print("随机索引:", index)
# 取出对应位置的元素
selected = torch.take(x, index)
print("原始张量形状:", x.shape)
print("选中的元素形状:", selected.shape)
print("选中的元素:", selected)
# 随机选择元素
x = torch.randn(10, 10)
# 随机生成10个索引
index = torch.randint(0, 100, (10,))
print("随机索引:", index)
# 取出对应位置的元素
selected = torch.take(x, index)
print("原始张量形状:", x.shape)
print("选中的元素形状:", selected.shape)
print("选中的元素:", selected)
输出结果为:
随机索引: tensor([12, 45, 67, 82, 35, 59, 92, 7, 28, 73]) 原始张量形状: torch.Size([10, 10]) 选中的元素形状: torch.Size([10]) 选中的元素: tensor([ 0.2345, -0.1234, 0.5678, 1.2345, -0.6789, 0.8901, -0.3456, 0.1234, -0.5678, 0.7890])
注意:torch.take 将输入张量视为扁平化的一维张量来索引。索引必须在有效范围内(0 到 numel-1),也可以使用负索引(-1 表示最后一个元素)。

Pytorch torch 参考手册