现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.dstack 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.dstack 是 PyTorch 中用于深度(沿第三维)堆叠张量的函数。

函数定义

torch.dstack(tensors, *, out=None)

使用示例

实例

import torch

# 二维张量深度堆叠
x1 = torch.tensor([[1, 2], [3, 4]])
x2 = torch.tensor([[5, 6], [7, 8]])
result = torch.dstack([x1, x2])
print("二维张量深度堆叠:")
print(f"  x1:n{x1}")
print(f"  x2:n{x2}")
print(f"  dstack:n{result}")

# 三维张量深度堆叠
y1 = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
y2 = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])
result = torch.dstack([y1, y2])
print("n三维张量深度堆叠:")
print(f"  y1:n{y1}")
print(f"  y2:n{y2}")
print(f"  dstack:n{result}")
print(f"  dstack 形状: {result.shape}")

# 一维张量深度堆叠
z1 = torch.tensor([1, 2, 3])
z2 = torch.tensor([4, 5, 6])
result = torch.dstack([z1, z2])
print("n一维张量深度堆叠:")
print(f"  z1: {z1}")
print(f"  z2: {z2}")
print(f"  dstack:n{result}")

输出结果为:

二维张量深度堆叠:
  x1:
tensor([[1, 2],
        [3, 4]])
  x2:
tensor([[5, 6],
        [7, 8]])
  dstack:
tensor([[[ 1,  5],
         [ 2,  6]],

        [[ 3,  7],
         [ 4,  8]]])

三维张量深度堆叠:
  y1:
tensor([[[ 1,  2],
         [ 3,  4]],

        [[ 5,  6],
         [ 7,  8]]])
  y2:
tensor([[[ 9, 10],
         [11, 12]],

        [[13, 14],
         [15, 16]]])
  dstack:
tensor([[[ 1,  2,  9, 10],
         [ 3,  4, 11, 12]],

        [[ 5,  6, 13, 14],
         [ 7,  8, 15, 16]]])
  dstack 形状: torch.Size([2, 2, 4])

一维张量深度堆叠:
  z1: tensor([1, 2, 3])
  z2: tensor([4, 5, 6])
  dstack:
tensor([[1, 4],
        [2, 5],
        [3, 6]])

Pytorch torch 参考手册 Pytorch torch 参考手册