PyTorch torch.unravel_index 函数
torch.unravel_index 是 PyTorch 中用于将展平的索引展开为多维索引的函数。与 torch.ravel_index 相反,它将一个在展平数组中的位置转换回多维数组中的坐标。
函数定义
torch.unravel_index(indices, shape)
参数说明
indices: 展平后的索引shape: 多维数组的形状
使用示例
实例
import torch
# 展平索引
indices = torch.tensor([0, 1, 5, 6, 7])
# 数组形状
shape = (2, 4)
# 展开为多维索引
result = torch.unravel_index(indices, shape)
print("展平索引:", indices)
print("数组形状:", shape)
print("多维索引:")
print(result)
# 展平索引
indices = torch.tensor([0, 1, 5, 6, 7])
# 数组形状
shape = (2, 4)
# 展开为多维索引
result = torch.unravel_index(indices, shape)
print("展平索引:", indices)
print("数组形状:", shape)
print("多维索引:")
print(result)
输出结果为:
展平索引: tensor([0, 1, 5, 6, 7])
数组形状: (2, 4)
多维索引:
tensor([[0, 0, 1, 1, 1],
[0, 1, 1, 0, 1]])

Pytorch torch 参考手册