现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.vsplit 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.vsplit 是 PyTorch 中用于垂直(沿行)分割张量的函数。

函数定义

torch.vsplit(input, indices_or_sections)

使用示例

实例

import torch

# 二维张量垂直分割
x = torch.arange(12).reshape(4, 3)
print("原始二维张量:")
print(x)

result = torch.vsplit(x, 2)
print("沿行平均分为 2 份:")
for i, t in enumerate(result):
    print(f"  块 {i}:n{t}")

# 按索引分割
result = torch.vsplit(x, [1, 3])
print("n按索引 [1, 3] 分割:")
for i, t in enumerate(result):
    print(f"  块 {i}:n{t}")

# 三维张量垂直分割
y = torch.arange(24).reshape(4, 3, 2)
print("n三维张量:")
print(y)

result = torch.vsplit(y, 2)
print("n沿第一维分为 2 份:")
for i, t in enumerate(result):
    print(f"  块 {i}: {t.shape}")

输出结果为:

原始二维张量:
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
沿行平均分为 2 份:
  块 0:
tensor([[0, 1, 2],
        [3, 4, 5]])
  块 1:
tensor([[ 6,  7,  8],
        [ 9, 10, 11]])

按索引 [1, 3] 分割:
  块 0:
tensor([[0, 1, 2]])
  块 1:
tensor([[3, 4, 5],
        [6, 7, 8]])
  块 2:
tensor([[ 9, 10, 11]])

三维张量:
tensor([[[ 0,  1],
         [ 3,  4],
         [ 5,  6]],

        [[ 7,  8],
         [ 9, 10],
         [11, 12]],

        [[13, 14],
         [15, 16],
         [17, 18]],

        [[19, 20],
         [21, 22],
         [23, 24]]])
沿第一维分为 2 份:
  块 0: torch.Size([2, 3, 2])
  块 1: torch.Size([2, 3, 2])

Pytorch torch 参考手册 Pytorch torch 参考手册