PyTorch torch.tensor_split 函数
torch.tensor_split 是 PyTorch 中用于按索引或段数分割张量的函数。
函数定义
torch.tensor_split(input, indices_or_sections, dim=0)
使用示例
实例
import torch
x = torch.arange(12)
print("原始张量:")
print(x)
# 按索引分割
result = torch.tensor_split(x, [2, 5, 8])
print("按索引 [2, 5, 8] 分割:")
for i, t in enumerate(result):
print(f" 块 {i}: {t}")
# 按段数分割(平均分为3份)
y = torch.arange(9)
result = torch.tensor_split(y, 3)
print("n平均分为 3 份:")
for i, t in enumerate(result):
print(f" 块 {i}: {t}")
# 对二维张量按列分割
z = torch.arange(12).reshape(3, 4)
print("n二维张量:")
print(z)
result = torch.tensor_split(z, 2, dim=1)
print("按列分为 2 份:")
for i, t in enumerate(result):
print(f" 块 {i}:n{t}")
x = torch.arange(12)
print("原始张量:")
print(x)
# 按索引分割
result = torch.tensor_split(x, [2, 5, 8])
print("按索引 [2, 5, 8] 分割:")
for i, t in enumerate(result):
print(f" 块 {i}: {t}")
# 按段数分割(平均分为3份)
y = torch.arange(9)
result = torch.tensor_split(y, 3)
print("n平均分为 3 份:")
for i, t in enumerate(result):
print(f" 块 {i}: {t}")
# 对二维张量按列分割
z = torch.arange(12).reshape(3, 4)
print("n二维张量:")
print(z)
result = torch.tensor_split(z, 2, dim=1)
print("按列分为 2 份:")
for i, t in enumerate(result):
print(f" 块 {i}:n{t}")
输出结果为:
原始张量:
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
按索引 [2, 5, 8] 分割:
块 0: tensor([0, 1])
块 1: tensor([2, 3, 4])
块 2: tensor([5, 6, 7])
块 3: tensor([ 8, 9, 10, 11])
平均分为 3 份:
块 0: tensor([0, 1, 2])
块 1: tensor([3, 4, 5])
块 2: tensor([6, 7, 8])
二维张量:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
按列分为 2 份:
块 0:
tensor([[0, 1],
[4, 5],
[8, 9]])
块 1:
tensor([[ 2, 3],
[ 6, 7],
[10, 11]])

Pytorch torch 参考手册