PyTorch torch.row_stack 函数
torch.row_stack 是 PyTorch 中用于按行堆叠张量的函数,相当于垂直堆叠。
函数定义
torch.row_stack(tensors, *, out=None)
使用示例
实例
import torch
# 一维张量行堆叠
x1 = torch.tensor([1, 2, 3])
x2 = torch.tensor([4, 5, 6])
result = torch.row_stack([x1, x2])
print("一维张量行堆叠:")
print(f" x1: {x1}")
print(f" x2: {x2}")
print(f" row_stack:n{result}")
# 二维张量行堆叠
y1 = torch.tensor([[1, 2, 3]])
y2 = torch.tensor([[4, 5, 6]])
result = torch.row_stack([y1, y2])
print("n二维张量行堆叠:")
print(f" y1:n{y1}")
print(f" y2:n{y2}")
print(f" row_stack:n{result}")
# 等价于 vstack
z1 = torch.tensor([1, 2, 3])
z2 = torch.tensor([4, 5, 6])
result_vstack = torch.vstack([z1, z2])
result_row_stack = torch.row_stack([z1, z2])
print("nrow_stack 等价于 vstack:")
print(f" vstack: {result_vstack.tolist()}")
print(f" row_stack: {result_row_stack.tolist()}")
# 一维张量行堆叠
x1 = torch.tensor([1, 2, 3])
x2 = torch.tensor([4, 5, 6])
result = torch.row_stack([x1, x2])
print("一维张量行堆叠:")
print(f" x1: {x1}")
print(f" x2: {x2}")
print(f" row_stack:n{result}")
# 二维张量行堆叠
y1 = torch.tensor([[1, 2, 3]])
y2 = torch.tensor([[4, 5, 6]])
result = torch.row_stack([y1, y2])
print("n二维张量行堆叠:")
print(f" y1:n{y1}")
print(f" y2:n{y2}")
print(f" row_stack:n{result}")
# 等价于 vstack
z1 = torch.tensor([1, 2, 3])
z2 = torch.tensor([4, 5, 6])
result_vstack = torch.vstack([z1, z2])
result_row_stack = torch.row_stack([z1, z2])
print("nrow_stack 等价于 vstack:")
print(f" vstack: {result_vstack.tolist()}")
print(f" row_stack: {result_row_stack.tolist()}")
输出结果为:
一维张量行堆叠:
x1: tensor([1, 2, 3])
x2: tensor([4, 5, 6])
row_stack:
tensor([[1, 2, 3],
[4, 5, 6]])
二维张量行堆叠:
y1:
tensor([[1, 2, 3]])
y2:
tensor([[4, 5, 6]])
row_stack:
tensor([[1, 2, 3],
[4, 5, 6]])
row_stack 等价于 vstack:
vstack: [[1, 2, 3], [4, 5, 6]]
row_stack: [[1, 2, 3], [4, 5, 6]]

Pytorch torch 参考手册