PyTorch torch.vsplit 函数
torch.vsplit 是 PyTorch 中用于垂直(沿行)分割张量的函数。
函数定义
torch.vsplit(input, indices_or_sections)
使用示例
实例
import torch
# 二维张量垂直分割
x = torch.arange(12).reshape(4, 3)
print("原始二维张量:")
print(x)
result = torch.vsplit(x, 2)
print("沿行平均分为 2 份:")
for i, t in enumerate(result):
print(f" 块 {i}:n{t}")
# 按索引分割
result = torch.vsplit(x, [1, 3])
print("n按索引 [1, 3] 分割:")
for i, t in enumerate(result):
print(f" 块 {i}:n{t}")
# 三维张量垂直分割
y = torch.arange(24).reshape(4, 3, 2)
print("n三维张量:")
print(y)
result = torch.vsplit(y, 2)
print("n沿第一维分为 2 份:")
for i, t in enumerate(result):
print(f" 块 {i}: {t.shape}")
# 二维张量垂直分割
x = torch.arange(12).reshape(4, 3)
print("原始二维张量:")
print(x)
result = torch.vsplit(x, 2)
print("沿行平均分为 2 份:")
for i, t in enumerate(result):
print(f" 块 {i}:n{t}")
# 按索引分割
result = torch.vsplit(x, [1, 3])
print("n按索引 [1, 3] 分割:")
for i, t in enumerate(result):
print(f" 块 {i}:n{t}")
# 三维张量垂直分割
y = torch.arange(24).reshape(4, 3, 2)
print("n三维张量:")
print(y)
result = torch.vsplit(y, 2)
print("n沿第一维分为 2 份:")
for i, t in enumerate(result):
print(f" 块 {i}: {t.shape}")
输出结果为:
原始二维张量:
tensor([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
沿行平均分为 2 份:
块 0:
tensor([[0, 1, 2],
[3, 4, 5]])
块 1:
tensor([[ 6, 7, 8],
[ 9, 10, 11]])
按索引 [1, 3] 分割:
块 0:
tensor([[0, 1, 2]])
块 1:
tensor([[3, 4, 5],
[6, 7, 8]])
块 2:
tensor([[ 9, 10, 11]])
三维张量:
tensor([[[ 0, 1],
[ 3, 4],
[ 5, 6]],
[[ 7, 8],
[ 9, 10],
[11, 12]],
[[13, 14],
[15, 16],
[17, 18]],
[[19, 20],
[21, 22],
[23, 24]]])
沿第一维分为 2 份:
块 0: torch.Size([2, 3, 2])
块 1: torch.Size([2, 3, 2])

Pytorch torch 参考手册