现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.unravel_index 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.unravel_index 是 PyTorch 中用于将展平的索引展开为多维索引的函数。与 torch.ravel_index 相反,它将一个在展平数组中的位置转换回多维数组中的坐标。

函数定义

torch.unravel_index(indices, shape)

参数说明

  • indices: 展平后的索引
  • shape: 多维数组的形状

使用示例

实例

import torch

# 展平索引
indices = torch.tensor([0, 1, 5, 6, 7])

# 数组形状
shape = (2, 4)

# 展开为多维索引
result = torch.unravel_index(indices, shape)

print("展平索引:", indices)
print("数组形状:", shape)
print("多维索引:")
print(result)

输出结果为:

展平索引: tensor([0, 1, 5, 6, 7])
数组形状: (2, 4)
多维索引:
tensor([[0, 0, 1, 1, 1],
        [0, 1, 1, 0, 1]])

Pytorch torch 参考手册 Pytorch torch 参考手册