PyTorch torch.unflatten 函数
torch.unflatten 是 PyTorch 中用于展开维度的函数。它将张量中已存在的维度展开为多个维度,用于重构张量形状。
函数定义
torch.unflatten(input, dim, sizes)
参数说明:
input: 输入张量dim: 要展开的维度索引sizes: 展开后的尺寸元组
使用示例
实例
import torch
# 创建一维张量
x = torch.arange(12)
# 展开为 3x4 矩阵
y = torch.unflatten(x, dim=0, sizes=(3, 4))
print(y.shape)
print(y)
# 创建一维张量
x = torch.arange(12)
# 展开为 3x4 矩阵
y = torch.unflatten(x, dim=0, sizes=(3, 4))
print(y.shape)
print(y)
输出结果为:
torch.Size([3, 4])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
实例
import torch
# 创建展平的张量
x = torch.randn(24)
# 展开为 2x3x4 三维张量
y = torch.unflatten(x, dim=0, sizes=(2, 3, 4))
print(y.shape)
# 创建展平的张量
x = torch.randn(24)
# 展开为 2x3x4 三维张量
y = torch.unflatten(x, dim=0, sizes=(2, 3, 4))
print(y.shape)
输出结果为:
torch.Size([2, 3, 4])
实例
import torch
# 创建展平的张量
x = torch.randn(12)
# 可以用名称指定维度
x = x.rename('N')
y = torch.unflatten(x, dim='N', sizes=(3, 4))
print(y.shape)
# 创建展平的张量
x = torch.randn(12)
# 可以用名称指定维度
x = x.rename('N')
y = torch.unflatten(x, dim='N', sizes=(3, 4))
print(y.shape)
输出结果为:
torch.Size([3, 4])

Pytorch torch 参考手册