PyTorch torch.squeeze 函数
torch.squeeze 是 PyTorch 中用于移除张量中大小为 1 的维度的函数。
函数定义
torch.squeeze(input, dim)
使用示例
实例
import torch
# 创建有大小为 1 的维度的张量
x = torch.randn(1, 2, 1, 3, 1)
print("原始形状:", x.shape)
# 移除所有大小为 1 的维度
y = torch.squeeze(x)
print("挤压后形状:", y.shape)
# 指定维度
z = torch.squeeze(x, dim=0)
print("dim=0 挤压后:", z.shape)
# 创建有大小为 1 的维度的张量
x = torch.randn(1, 2, 1, 3, 1)
print("原始形状:", x.shape)
# 移除所有大小为 1 的维度
y = torch.squeeze(x)
print("挤压后形状:", y.shape)
# 指定维度
z = torch.squeeze(x, dim=0)
print("dim=0 挤压后:", z.shape)
输出结果为:
原始形状: torch.Size([1, 2, 1, 3, 1]) 挤压后形状: torch.Size([2, 3]) dim=0 挤压后: torch.Size([2, 1, 3, 1])

Pytorch torch 参考手册