现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.unflatten 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.unflatten 是 PyTorch 中用于展开维度的函数。它将张量中已存在的维度展开为多个维度,用于重构张量形状。

函数定义

torch.unflatten(input, dim, sizes)

参数说明:

  • input: 输入张量
  • dim: 要展开的维度索引
  • sizes: 展开后的尺寸元组

使用示例

实例

import torch

# 创建一维张量
x = torch.arange(12)

# 展开为 3x4 矩阵
y = torch.unflatten(x, dim=0, sizes=(3, 4))
print(y.shape)
print(y)

输出结果为:

torch.Size([3, 4])
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

实例

import torch

# 创建展平的张量
x = torch.randn(24)

# 展开为 2x3x4 三维张量
y = torch.unflatten(x, dim=0, sizes=(2, 3, 4))
print(y.shape)

输出结果为:

torch.Size([2, 3, 4])

实例

import torch

# 创建展平的张量
x = torch.randn(12)

# 可以用名称指定维度
x = x.rename('N')
y = torch.unflatten(x, dim='N', sizes=(3, 4))
print(y.shape)

输出结果为:

torch.Size([3, 4])

Pytorch torch 参考手册 Pytorch torch 参考手册