现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.take_along_dim 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.take_along_dim 是 PyTorch 中用于沿指定维度获取索引位置元素的函数。它根据 indices 中指定的索引沿 dim 维度从 input 中取值。

函数定义

torch.take_along_dim(input, indices, dim)

参数:

  • input (Tensor): 输入张量。
  • indices (Tensor): 索引张量,指定要获取的元素位置。形状必须与 input 在 dim 维度上兼容。
  • dim (int): 要沿着的维度。

返回值:

  • torch.Tensor: 返回按索引取出的元素组成的新张量。

使用示例

实例

import torch

# 创建一个2D张量
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

print("原始张量:")
print(x)

# 沿 dim=1 取元素
indices = torch.tensor([[0, 1, 2],
                        [2, 1, 0],
                        [0, 0, 0]])
y = torch.take_along_dim(x, indices, dim=1)

print("n索引:")
print(indices)
print("n沿 dim=1 取出的元素:")
print(y)

输出结果为:

原始张量:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

索引:
tensor([[0, 1, 2],
        [2, 1, 0],
        [0, 0, 0]])

沿 dim=1 取出的元素:
tensor([[1, 2, 3],
        [6, 5, 4],
        [7, 7, 7]])

实例

import torch

# 沿 dim=0 取元素
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# 每列取不同的行
indices = torch.tensor([[0, 1, 2],
                        [2, 0, 1],
                        [1, 2, 0]])
y = torch.take_along_dim(x, indices, dim=0)

print("原始张量:")
print(x)
print("n索引:")
print(indices)
print("n沿 dim=0 取出的元素:")
print(y)

输出结果为:

原始张量:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

索引:
tensor([[0, 1, 2],
        [2, 0, 1],
        [1, 2, 0]])

沿 dim=0 取出的元素:
tensor([[1, 5, 9],
        [7, 2, 6],
        [4, 8, 3]])
</p>

<div class="example">
<h2 class="example">实例</h2>
<div class="example_code">
import torch

# 在3D张量上使用
x = torch.arange(24).reshape(2, 3, 4)
print("原始形状:", x.shape)

# 沿 dim=1 取元素
indices = torch.tensor([[0, 1, 2],
                        [2, 0, 1]])
y = torch.take_along_dim(x, indices, dim=1)

print("索引形状:", indices.shape)
print("结果形状:", y.shape)
print("n结果:")
print(y)
</div> </div> <p>输出结果为:</p> <pre> 原始形状: torch.Size([2, 3, 4]) 索引形状: torch.Size([2, 3]) 结果形状: torch.Size([2, 3, 4]) 结果: tensor([[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], [[16, 17, 18, 19], [12, 13, 14, 15], [20, 21, 22, 23]]])

实例

import torch

# 应用:沿批次维度选择特定的元素
# 比如在注意力机制中选择特定的key-value

batch_size = 2
num_heads = 3
seq_len = 4
head_dim = 5

# 模拟query的注意力权重
attn_weights = torch.randn(batch_size, num_heads, seq_len)

# 对每个head取top-k的索引
k = 2
indices = torch.argsort(attn_weights, dim=-1, descending=True)[..., :k]
print("索引形状:", indices.shape)

# 模拟value张量
value = torch.randn(batch_size, num_heads, seq_len, head_dim)

# 沿seq_len维度取对应的value
selected_value = torch.take_along_dim(value, indices.unsqueeze(-1).expand(-1, -1, -1, head_dim), dim=2)

print("Value形状:", value.shape)
print("选择的Value形状:", selected_value.shape)

输出结果为:

索引形状: torch.Size([2, 3, 4])
Value形状: torch.Size([2, 3, 4, 5])
选择的Value形状: torch.Size([2, 3, 2, 5])


注意:torch.take_along_dim 允许按维度进行索引,这比 torch.take 更加灵活,因为后者总是将张量视为一维的。


Pytorch torch 参考手册 Pytorch torch 参考手册