现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.dsplit 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.dsplit 是 PyTorch 中用于深度(沿第三维)分割张量的函数。

函数定义

torch.dsplit(input, indices_or_sections)

使用示例

实例

import torch

# 三维张量深度分割
x = torch.arange(24).reshape(2, 3, 4)
print("原始三维张量:")
print(x)

result = torch.dsplit(x, 2)
print("沿深度分为 2 份:")
for i, t in enumerate(result):
    print(f"  块 {i}:n{t}")

# 按索引分割
result = torch.dsplit(x, [1, 3])
print("n按索引 [1, 3] 分割:")
for i, t in enumerate(result):
    print(f"  块 {i}:n{t}")

# 四维张量深度分割
y = torch.arange(32).reshape(2, 2, 4, 2)
print("n四维张量形状:", y.shape)

result = torch.dsplit(y, 2)
print("沿第三维分为 2 份:")
for i, t in enumerate(result):
    print(f"  块 {i} 形状: {t.shape}")

输出结果为:

原始三维张量:
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
沿深度分为 2 份:
  块 0:
tensor([[[ 0,  1],
         [ 4,  5],
         [ 8,  9]],

        [[12, 13],
         [16, 17],
         [20, 21]]])
  块 1:
tensor([[[ 2,  3],
         [ 6,  7],
         [10, 11]],

        [[14, 15],
         [18, 19],
         [22, 23]]])

按索引 [1, 3] 分割:
  块 0:
tensor([[[ 0],
         [ 4],
         [ 8]],

        [[12],
         [16],
         [20]]])
  块 1:
tensor([[[ 1,  2],
         [ 5,  6],
         [ 9, 10]],

        [[13, 14],
         [17, 18],
         [21, 22]]])
  块 2:
tensor([[[ 3],
         [ 7],
         [11]],

        [[15],
         [19],
         [23]]])

四维张量形状: torch.Size([2, 2, 4, 2])
沿第三维分为 2 份:
  块 0 形状: torch.Size([2, 2, 2, 2])
  块 1 形状: torch.Size([2, 2, 2, 2])

Pytorch torch 参考手册 Pytorch torch 参考手册