现在位置: 首页 > PyTorch 教程 > 正文

PyTorch torch.concatenate 函数


Pytorch torch 参考手册 Pytorch torch 参考手册

torch.concatenate 是 PyTorch 中用于沿指定维度连接多个张量的函数。它与 torch.cat 是同一个函数,用于将多个张量沿指定维度拼接成一个更大的张量。

函数定义

torch.concatenate(tensors, dim=0, out=None)

参数:

  • tensors (Sequence of Tensor): 要连接的张量序列。所有张量必须在除连接维度外的其他维度上形状相同。
  • dim (int, 可选): 连接的维度,默认为 0。
  • out (Tensor, 可选): 输出张量。

返回值:

  • torch.Tensor: 返回连接后的张量。

使用示例

实例

import torch

# 创建多个张量
a = torch.tensor([1, 2])
b = torch.tensor([3, 4])
c = torch.tensor([5, 6])

# 连接多个张量
result = torch.concatenate([a, b, c])

print(result)

输出结果为:

tensor([1, 2, 3, 4, 5, 6])

实例

import torch

# 创建两个二维张量
a = torch.randn(2, 3)
b = torch.randn(2, 3)

# 沿第一维度连接
c = torch.concatenate([a, b], dim=0)

print("a 的形状:", a.shape)
print("b 的形状:", b.shape)
print("c 的形状:", c.shape)

输出结果为:

a 的形状: torch.Size([2, 3])
b 的形状: torch.Size([2, 3])
c 的形状: torch.Size([4, 3])

注意:torch.concatenatetorch.cat 的别名,两者的功能完全相同。在实际代码中,更常用的是 torch.cat


Pytorch torch 参考手册 Pytorch torch 参考手册