PyTorch torch.split_with_sizes 函数
torch.split_with_sizes 是 PyTorch 中用于按指定大小分割张量的函数。
函数定义
torch.split_with_sizes(input, split_sizes, dim=0)
使用示例
实例
import torch
x = torch.arange(10)
print("原始张量:")
print(x)
# 按大小 [2, 3, 5] 分割
result = torch.split_with_sizes(x, [2, 3, 5])
print("分割结果:")
for i, t in enumerate(result):
print(f" 块 {i}: {t}")
# 对二维张量按行分割
y = torch.arange(12).reshape(4, 3)
print("n原始二维张量:")
print(y)
result = torch.split_with_sizes(y, [1, 2, 1], dim=0)
print("按行分割 [1, 2, 1]:")
for i, t in enumerate(result):
print(f" 块 {i}:n{t}")
x = torch.arange(10)
print("原始张量:")
print(x)
# 按大小 [2, 3, 5] 分割
result = torch.split_with_sizes(x, [2, 3, 5])
print("分割结果:")
for i, t in enumerate(result):
print(f" 块 {i}: {t}")
# 对二维张量按行分割
y = torch.arange(12).reshape(4, 3)
print("n原始二维张量:")
print(y)
result = torch.split_with_sizes(y, [1, 2, 1], dim=0)
print("按行分割 [1, 2, 1]:")
for i, t in enumerate(result):
print(f" 块 {i}:n{t}")
输出结果为:
原始张量:
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
分割结果:
块 0: tensor([0, 1])
块 1: tensor([2, 3, 4])
块 2: tensor([5, 6, 7, 8, 9])
原始二维张量:
tensor([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
按行分割 [1, 2, 1]:
块 0:
tensor([[0, 1, 2]])
块 1:
tensor([[3, 4, 5],
[6, 7, 8]])
块 2:
tensor([[ 9, 10, 11]])

Pytorch torch 参考手册