现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.unsqueeze 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.unsqueeze 是 PyTorch 中用于在指定位置插入大小为 1 的维度的函数。

函数定义

torch.unsqueeze(input, dim)

使用示例

实例

import torch

x = torch.randn(2, 3)
print("原始形状:", x.shape)

# 在 dim=0 插入维度
y = torch.unsqueeze(x, dim=0)
print("dim=0 扩展后:", y.shape)

# 在 dim=1 插入维度
z = torch.unsqueeze(x, dim=1)
print("dim=1 扩展后:", z.shape)

输出结果为:

原始形状: torch.Size([2, 3])
dim=0 扩展后: torch.Size([1, 2, 3])
dim=1 扩展后: torch.Size([2, 1, 3])

Pytorch torch 参考手册 Pytorch torch 参考手册