PyTorch torch.dsplit 函数
torch.dsplit 是 PyTorch 中用于深度(沿第三维)分割张量的函数。
函数定义
torch.dsplit(input, indices_or_sections)
使用示例
实例
import torch
# 三维张量深度分割
x = torch.arange(24).reshape(2, 3, 4)
print("原始三维张量:")
print(x)
result = torch.dsplit(x, 2)
print("沿深度分为 2 份:")
for i, t in enumerate(result):
print(f" 块 {i}:n{t}")
# 按索引分割
result = torch.dsplit(x, [1, 3])
print("n按索引 [1, 3] 分割:")
for i, t in enumerate(result):
print(f" 块 {i}:n{t}")
# 四维张量深度分割
y = torch.arange(32).reshape(2, 2, 4, 2)
print("n四维张量形状:", y.shape)
result = torch.dsplit(y, 2)
print("沿第三维分为 2 份:")
for i, t in enumerate(result):
print(f" 块 {i} 形状: {t.shape}")
# 三维张量深度分割
x = torch.arange(24).reshape(2, 3, 4)
print("原始三维张量:")
print(x)
result = torch.dsplit(x, 2)
print("沿深度分为 2 份:")
for i, t in enumerate(result):
print(f" 块 {i}:n{t}")
# 按索引分割
result = torch.dsplit(x, [1, 3])
print("n按索引 [1, 3] 分割:")
for i, t in enumerate(result):
print(f" 块 {i}:n{t}")
# 四维张量深度分割
y = torch.arange(32).reshape(2, 2, 4, 2)
print("n四维张量形状:", y.shape)
result = torch.dsplit(y, 2)
print("沿第三维分为 2 份:")
for i, t in enumerate(result):
print(f" 块 {i} 形状: {t.shape}")
输出结果为:
原始三维张量:
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
沿深度分为 2 份:
块 0:
tensor([[[ 0, 1],
[ 4, 5],
[ 8, 9]],
[[12, 13],
[16, 17],
[20, 21]]])
块 1:
tensor([[[ 2, 3],
[ 6, 7],
[10, 11]],
[[14, 15],
[18, 19],
[22, 23]]])
按索引 [1, 3] 分割:
块 0:
tensor([[[ 0],
[ 4],
[ 8]],
[[12],
[16],
[20]]])
块 1:
tensor([[[ 1, 2],
[ 5, 6],
[ 9, 10]],
[[13, 14],
[17, 18],
[21, 22]]])
块 2:
tensor([[[ 3],
[ 7],
[11]],
[[15],
[19],
[23]]])
四维张量形状: torch.Size([2, 2, 4, 2])
沿第三维分为 2 份:
块 0 形状: torch.Size([2, 2, 2, 2])
块 1 形状: torch.Size([2, 2, 2, 2])

Pytorch torch 参考手册