现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.stack 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.stack 是 PyTorch 中用于沿新维度堆叠多个张量的函数。它会创建一个新的维度,并将所有输入张量放在这个新维度上。

这在深度学习中常用于创建批量数据、堆叠多个模型的输出等场景。

函数定义

torch.stack(tensors, dim=0, out=None)

参数:

  • tensors (Sequence of Tensor): 要堆叠的张量序列。所有张量必须具有相同的形状。
  • dim (int, 可选): 堆叠的维度,默认为 0。新维度会被插入到这个位置。
  • out (Tensor, 可选): 输出张量。

返回值:

  • torch.Tensor: 返回堆叠后的张量。

使用示例

示例 1: 基础堆叠

实例

import torch

# 创建两个一维张量
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

# 沿新维度堆叠
c = torch.stack([a, b])

print("a 的形状:", a.shape)
print("b 的形状:", b.shape)
print("c 的形状:", c.shape)
print(c)

输出结果为:

a 的形状: torch.Size([3])
b 的形状: torch.Size([3])
c 的形状: torch.Size([2, 3])
tensor([[1, 2, 3],
        [4, 5, 6]])

示例 2: 指定堆叠维度

实例

import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

# 沿 dim=1 堆叠
c = torch.stack([a, b], dim=1)

print("c 的形状:", c.shape)
print(c)

输出结果为:

c 的形状: torch.Size([3, 2])
tensor([[1, 4],
        [2, 5],
        [3, 6]])

示例 3: 堆叠多个张量

实例

import torch

# 创建多个张量
tensors = [torch.tensor([i, i+1, i+2]) for i in range(5)]

# 堆叠所有张量
result = torch.stack(tensors)

print("结果形状:", result.shape)
print(result)

输出结果为:

结果形状: torch.Size([5, 3])
tensor([[0, 1, 2],
        [1, 2, 3],
        [2, 3, 4],
        [3, 4, 5],
        [4, 5, 6]])

示例 4: 在神经网络中保存多个状态

实例

import torch

# 模拟保存多个 epoch 的损失值
losses = []
for epoch in range(5):
    loss = torch.tensor([epoch * 0.1, (epoch + 1) * 0.1])
    losses.append(loss)

# 堆叠所有 epoch 的损失
all_losses = torch.stack(losses)

print("各 epoch 损失形状:", all_losses.shape)
print(all_losses)

输出结果为:

各 epoch 损失形状: torch.Size([5, 2])
tensor([[0.0000, 0.1000],
        [0.1000, 0.2000],
        [0.2000, 0.3000],
        [0.3000, 0.4000],
        [0.4000, 0.5000]])

torch.stack 与 torch.cat 的区别

实例

import torch

a = torch.randn(2, 3)
b = torch.randn(2, 3)

# stack 会创建新维度
stack_result = torch.stack([a, b])
print("stack 结果形状:", stack_result.shape)

# cat 不会创建新维度
cat_result = torch.cat([a, b], dim=0)
print("cat 结果形状:", cat_result.shape)

输出结果为:

stack 结果形状: torch.Size([2, 2, 3])
cat 结果形状: torch.Size([4, 3])
  • torch.stack: 沿新维度堆叠,输入张量形状必须完全相同,会增加一个新维度。
  • torch.cat: 沿现有维度连接,输入张量在连接维度可以不同,不会增加新维度。

Pytorch torch 参考手册 Pytorch torch 参考手册