PyTorch torch.stack 函数
torch.stack 是 PyTorch 中用于沿新维度堆叠多个张量的函数。它会创建一个新的维度,并将所有输入张量放在这个新维度上。
这在深度学习中常用于创建批量数据、堆叠多个模型的输出等场景。
函数定义
torch.stack(tensors, dim=0, out=None)
参数:
tensors(Sequence of Tensor): 要堆叠的张量序列。所有张量必须具有相同的形状。dim(int, 可选): 堆叠的维度,默认为 0。新维度会被插入到这个位置。out(Tensor, 可选): 输出张量。
返回值:
torch.Tensor: 返回堆叠后的张量。
使用示例
示例 1: 基础堆叠
实例
import torch
# 创建两个一维张量
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 沿新维度堆叠
c = torch.stack([a, b])
print("a 的形状:", a.shape)
print("b 的形状:", b.shape)
print("c 的形状:", c.shape)
print(c)
# 创建两个一维张量
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 沿新维度堆叠
c = torch.stack([a, b])
print("a 的形状:", a.shape)
print("b 的形状:", b.shape)
print("c 的形状:", c.shape)
print(c)
输出结果为:
a 的形状: torch.Size([3])
b 的形状: torch.Size([3])
c 的形状: torch.Size([2, 3])
tensor([[1, 2, 3],
[4, 5, 6]])
示例 2: 指定堆叠维度
实例
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 沿 dim=1 堆叠
c = torch.stack([a, b], dim=1)
print("c 的形状:", c.shape)
print(c)
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 沿 dim=1 堆叠
c = torch.stack([a, b], dim=1)
print("c 的形状:", c.shape)
print(c)
输出结果为:
c 的形状: torch.Size([3, 2])
tensor([[1, 4],
[2, 5],
[3, 6]])
示例 3: 堆叠多个张量
实例
import torch
# 创建多个张量
tensors = [torch.tensor([i, i+1, i+2]) for i in range(5)]
# 堆叠所有张量
result = torch.stack(tensors)
print("结果形状:", result.shape)
print(result)
# 创建多个张量
tensors = [torch.tensor([i, i+1, i+2]) for i in range(5)]
# 堆叠所有张量
result = torch.stack(tensors)
print("结果形状:", result.shape)
print(result)
输出结果为:
结果形状: torch.Size([5, 3])
tensor([[0, 1, 2],
[1, 2, 3],
[2, 3, 4],
[3, 4, 5],
[4, 5, 6]])
示例 4: 在神经网络中保存多个状态
实例
import torch
# 模拟保存多个 epoch 的损失值
losses = []
for epoch in range(5):
loss = torch.tensor([epoch * 0.1, (epoch + 1) * 0.1])
losses.append(loss)
# 堆叠所有 epoch 的损失
all_losses = torch.stack(losses)
print("各 epoch 损失形状:", all_losses.shape)
print(all_losses)
# 模拟保存多个 epoch 的损失值
losses = []
for epoch in range(5):
loss = torch.tensor([epoch * 0.1, (epoch + 1) * 0.1])
losses.append(loss)
# 堆叠所有 epoch 的损失
all_losses = torch.stack(losses)
print("各 epoch 损失形状:", all_losses.shape)
print(all_losses)
输出结果为:
各 epoch 损失形状: torch.Size([5, 2])
tensor([[0.0000, 0.1000],
[0.1000, 0.2000],
[0.2000, 0.3000],
[0.3000, 0.4000],
[0.4000, 0.5000]])
torch.stack 与 torch.cat 的区别
实例
import torch
a = torch.randn(2, 3)
b = torch.randn(2, 3)
# stack 会创建新维度
stack_result = torch.stack([a, b])
print("stack 结果形状:", stack_result.shape)
# cat 不会创建新维度
cat_result = torch.cat([a, b], dim=0)
print("cat 结果形状:", cat_result.shape)
a = torch.randn(2, 3)
b = torch.randn(2, 3)
# stack 会创建新维度
stack_result = torch.stack([a, b])
print("stack 结果形状:", stack_result.shape)
# cat 不会创建新维度
cat_result = torch.cat([a, b], dim=0)
print("cat 结果形状:", cat_result.shape)
输出结果为:
stack 结果形状: torch.Size([2, 2, 3]) cat 结果形状: torch.Size([4, 3])
torch.stack: 沿新维度堆叠,输入张量形状必须完全相同,会增加一个新维度。torch.cat: 沿现有维度连接,输入张量在连接维度可以不同,不会增加新维度。

Pytorch torch 参考手册