PyTorch torch.flatten 函数
torch.flatten 是 PyTorch 中用于将张量展平的函数。
函数定义
torch.flatten(input, start_dim, end_dim, out)
使用示例
实例
import torch
x = torch.randn(2, 3, 4)
# 展平为一维
y = torch.flatten(x)
print("原始形状:", x.shape)
print("展平后:", y.shape)
# 从指定维度展平
z = torch.flatten(x, start_dim=1)
print("从 dim=1 展平:", z.shape)
x = torch.randn(2, 3, 4)
# 展平为一维
y = torch.flatten(x)
print("原始形状:", x.shape)
print("展平后:", y.shape)
# 从指定维度展平
z = torch.flatten(x, start_dim=1)
print("从 dim=1 展平:", z.shape)
输出结果为:
原始形状: torch.Size([2, 3, 4]) 展平后: torch.Size([24]) 从 dim=1 展平: torch.Size([2, 12])

Pytorch torch 参考手册