PyTorch torch.hsplit 函数
torch.hsplit 是 PyTorch 中用于水平(沿列)分割张量的函数。
函数定义
torch.hsplit(input, indices_or_sections)
使用示例
实例
import torch
# 一维张量水平分割
x = torch.arange(10)
print("原始一维张量:")
print(x)
result = torch.hsplit(x, 2)
print("平均分为 2 份:")
for i, t in enumerate(result):
print(f" 块 {i}: {t}")
# 二维张量水平分割
y = torch.arange(12).reshape(3, 4)
print("n原始二维张量:")
print(y)
result = torch.hsplit(y, 2)
print("沿列分为 2 份:")
for i, t in enumerate(result):
print(f" 块 {i}:n{t}")
# 按索引分割
result = torch.hsplit(y, [1, 3])
print("n按索引 [1, 3] 分割:")
for i, t in enumerate(result):
print(f" 块 {i}:n{t}")
# 一维张量水平分割
x = torch.arange(10)
print("原始一维张量:")
print(x)
result = torch.hsplit(x, 2)
print("平均分为 2 份:")
for i, t in enumerate(result):
print(f" 块 {i}: {t}")
# 二维张量水平分割
y = torch.arange(12).reshape(3, 4)
print("n原始二维张量:")
print(y)
result = torch.hsplit(y, 2)
print("沿列分为 2 份:")
for i, t in enumerate(result):
print(f" 块 {i}:n{t}")
# 按索引分割
result = torch.hsplit(y, [1, 3])
print("n按索引 [1, 3] 分割:")
for i, t in enumerate(result):
print(f" 块 {i}:n{t}")
输出结果为:
原始一维张量:
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
平均分为 2 份:
块 0: tensor([0, 1, 2, 3, 4])
块 1: tensor([5, 6, 7, 8, 9])
原始二维张量:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
沿列分为 2 份:
块 0:
tensor([[0, 1],
[4, 5],
[8, 9]])
块 1:
tensor([[ 2, 3],
[ 6, 7],
[10, 11]])
按索引 [1, 3] 分割:
块 0:
tensor([[0],
[4],
[8]])
块 1:
tensor([[ 1, 2],
[ 5, 6],
[ 9, 10]])
块 2:
tensor([[ 3],
[ 7],
[11]])

Pytorch torch 参考手册