PyTorch torch.unsqueeze 函数
torch.unsqueeze 是 PyTorch 中用于在指定位置插入大小为 1 的维度的函数。
函数定义
torch.unsqueeze(input, dim)
使用示例
实例
import torch
x = torch.randn(2, 3)
print("原始形状:", x.shape)
# 在 dim=0 插入维度
y = torch.unsqueeze(x, dim=0)
print("dim=0 扩展后:", y.shape)
# 在 dim=1 插入维度
z = torch.unsqueeze(x, dim=1)
print("dim=1 扩展后:", z.shape)
x = torch.randn(2, 3)
print("原始形状:", x.shape)
# 在 dim=0 插入维度
y = torch.unsqueeze(x, dim=0)
print("dim=0 扩展后:", y.shape)
# 在 dim=1 插入维度
z = torch.unsqueeze(x, dim=1)
print("dim=1 扩展后:", z.shape)
输出结果为:
原始形状: torch.Size([2, 3]) dim=0 扩展后: torch.Size([1, 2, 3]) dim=1 扩展后: torch.Size([2, 1, 3])

Pytorch torch 参考手册