PyTorch torch.take_along_dim 函数
torch.take_along_dim 是 PyTorch 中用于沿指定维度获取索引位置元素的函数。它根据 indices 中指定的索引沿 dim 维度从 input 中取值。
函数定义
torch.take_along_dim(input, indices, dim)
参数:
input(Tensor): 输入张量。indices(Tensor): 索引张量,指定要获取的元素位置。形状必须与 input 在 dim 维度上兼容。dim(int): 要沿着的维度。
返回值:
torch.Tensor: 返回按索引取出的元素组成的新张量。
使用示例
实例
import torch
# 创建一个2D张量
x = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
print("原始张量:")
print(x)
# 沿 dim=1 取元素
indices = torch.tensor([[0, 1, 2],
[2, 1, 0],
[0, 0, 0]])
y = torch.take_along_dim(x, indices, dim=1)
print("n索引:")
print(indices)
print("n沿 dim=1 取出的元素:")
print(y)
# 创建一个2D张量
x = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
print("原始张量:")
print(x)
# 沿 dim=1 取元素
indices = torch.tensor([[0, 1, 2],
[2, 1, 0],
[0, 0, 0]])
y = torch.take_along_dim(x, indices, dim=1)
print("n索引:")
print(indices)
print("n沿 dim=1 取出的元素:")
print(y)
输出结果为:
原始张量:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
索引:
tensor([[0, 1, 2],
[2, 1, 0],
[0, 0, 0]])
沿 dim=1 取出的元素:
tensor([[1, 2, 3],
[6, 5, 4],
[7, 7, 7]])
实例
import torch
# 沿 dim=0 取元素
x = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 每列取不同的行
indices = torch.tensor([[0, 1, 2],
[2, 0, 1],
[1, 2, 0]])
y = torch.take_along_dim(x, indices, dim=0)
print("原始张量:")
print(x)
print("n索引:")
print(indices)
print("n沿 dim=0 取出的元素:")
print(y)
# 沿 dim=0 取元素
x = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 每列取不同的行
indices = torch.tensor([[0, 1, 2],
[2, 0, 1],
[1, 2, 0]])
y = torch.take_along_dim(x, indices, dim=0)
print("原始张量:")
print(x)
print("n索引:")
print(indices)
print("n沿 dim=0 取出的元素:")
print(y)
输出结果为:
原始张量:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
索引:
tensor([[0, 1, 2],
[2, 0, 1],
[1, 2, 0]])
沿 dim=0 取出的元素:
tensor([[1, 5, 9],
[7, 2, 6],
[4, 8, 3]])
</p>
<div class="example">
<h2 class="example">实例</h2>
<div class="example_code">
import torch
# 在3D张量上使用
x = torch.arange(24).reshape(2, 3, 4)
print("原始形状:", x.shape)
# 沿 dim=1 取元素
indices = torch.tensor([[0, 1, 2],
[2, 0, 1]])
y = torch.take_along_dim(x, indices, dim=1)
print("索引形状:", indices.shape)
print("结果形状:", y.shape)
print("n结果:")
print(y)
</div>
</div>
<p>输出结果为:</p>
<pre>
原始形状: torch.Size([2, 3, 4])
索引形状: torch.Size([2, 3])
结果形状: torch.Size([2, 3, 4])
结果:
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[16, 17, 18, 19],
[12, 13, 14, 15],
[20, 21, 22, 23]]])
实例
import torch
# 应用:沿批次维度选择特定的元素
# 比如在注意力机制中选择特定的key-value
batch_size = 2
num_heads = 3
seq_len = 4
head_dim = 5
# 模拟query的注意力权重
attn_weights = torch.randn(batch_size, num_heads, seq_len)
# 对每个head取top-k的索引
k = 2
indices = torch.argsort(attn_weights, dim=-1, descending=True)[..., :k]
print("索引形状:", indices.shape)
# 模拟value张量
value = torch.randn(batch_size, num_heads, seq_len, head_dim)
# 沿seq_len维度取对应的value
selected_value = torch.take_along_dim(value, indices.unsqueeze(-1).expand(-1, -1, -1, head_dim), dim=2)
print("Value形状:", value.shape)
print("选择的Value形状:", selected_value.shape)
# 应用:沿批次维度选择特定的元素
# 比如在注意力机制中选择特定的key-value
batch_size = 2
num_heads = 3
seq_len = 4
head_dim = 5
# 模拟query的注意力权重
attn_weights = torch.randn(batch_size, num_heads, seq_len)
# 对每个head取top-k的索引
k = 2
indices = torch.argsort(attn_weights, dim=-1, descending=True)[..., :k]
print("索引形状:", indices.shape)
# 模拟value张量
value = torch.randn(batch_size, num_heads, seq_len, head_dim)
# 沿seq_len维度取对应的value
selected_value = torch.take_along_dim(value, indices.unsqueeze(-1).expand(-1, -1, -1, head_dim), dim=2)
print("Value形状:", value.shape)
print("选择的Value形状:", selected_value.shape)
输出结果为:
索引形状: torch.Size([2, 3, 4]) Value形状: torch.Size([2, 3, 4, 5]) 选择的Value形状: torch.Size([2, 3, 2, 5])
注意:
torch.take_along_dim允许按维度进行索引,这比torch.take更加灵活,因为后者总是将张量视为一维的。

Pytorch torch 参考手册