现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.take 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.take 是 PyTorch 中用于获取给定索引位置元素的函数。它将输入张量视为一维数组,然后返回指定索引位置的元素。

函数定义

torch.take(input, index)

参数:

  • input (Tensor): 输入张量。
  • index (Tensor): 整数索引张量,指定要获取的元素位置。

返回值:

  • torch.Tensor: 返回指定索引位置的元素组成的新张量。

使用示例

实例

import torch

# 创建一个2D张量
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

print("原始张量:")
print(x)

# 将2D张量视为一维数组,索引0-8
# 取索引0, 4, 8位置的元素
index = torch.tensor([0, 4, 8])
y = torch.take(x, index)

print("n索引:", index)
print("取出的元素:", y)

输出结果为:

原始张量:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

索引: tensor([0, 4, 8])
取出的元素: tensor([1, 5, 9])

实例

import torch

# 3D张量
x = torch.arange(24).reshape(2, 3, 4)

print("原始3D张量:")
print(x)
print("形状:", x.shape)

# 取多个索引的元素
index = torch.tensor([0, 1, 2, 10, 20, 23])
y = torch.take(x, index)

print("n索引:", index)
print("取出的元素:", y)

输出结果为:

原始3D张量:
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8, 9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
形状: torch.Size([2, 3, 4])

索引: tensor([ 0,  1,  2, 10, 20, 23])
取出的元素: tensor([ 0,  1,  2, 10, 20, 23])

实例

import torch

# 使用负索引
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# 负索引从末尾开始计算
index = torch.tensor([0, -1])  # 第一个和最后一个元素
y = torch.take(x, index)

print("原始:", x)
print("索引 [0, -1]:", y)

输出结果为:

原始: tensor([[1, 2, 3],
        [4, 5, 6]])
索引 [0, -1]: tensor([1, 6])

实例

import torch

# 随机选择元素
x = torch.randn(10, 10)

# 随机生成10个索引
index = torch.randint(0, 100, (10,))
print("随机索引:", index)

# 取出对应位置的元素
selected = torch.take(x, index)

print("原始张量形状:", x.shape)
print("选中的元素形状:", selected.shape)
print("选中的元素:", selected)

输出结果为:

随机索引: tensor([12, 45, 67, 82, 35, 59, 92,  7, 28, 73])
原始张量形状: torch.Size([10, 10])
选中的元素形状: torch.Size([10])
选中的元素: tensor([ 0.2345, -0.1234,  0.5678,  1.2345, -0.6789,  0.8901, -0.3456,  0.1234, -0.5678,  0.7890])

注意:torch.take 将输入张量视为扁平化的一维张量来索引。索引必须在有效范围内(0 到 numel-1),也可以使用负索引(-1 表示最后一个元素)。


Pytorch torch 参考手册 Pytorch torch 参考手册