现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.squeeze 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.squeeze 是 PyTorch 中用于移除张量中大小为 1 的维度的函数。

函数定义

torch.squeeze(input, dim)

使用示例

实例

import torch

# 创建有大小为 1 的维度的张量
x = torch.randn(1, 2, 1, 3, 1)
print("原始形状:", x.shape)

# 移除所有大小为 1 的维度
y = torch.squeeze(x)
print("挤压后形状:", y.shape)

# 指定维度
z = torch.squeeze(x, dim=0)
print("dim=0 挤压后:", z.shape)

输出结果为:

原始形状: torch.Size([1, 2, 1, 3, 1])
挤压后形状: torch.Size([2, 3])
dim=0 挤压后: torch.Size([2, 1, 3, 1])

Pytorch torch 参考手册 Pytorch torch 参考手册