PyTorch torch.concat 函数
torch.concat 是 PyTorch 中用于沿指定维度连接多个张量的函数。它与 torch.cat 是同一个函数,用于将多个张量沿指定维度拼接成一个更大的张量。
函数定义
torch.concat(tensors, dim=0, out=None)
参数:
tensors(Sequence of Tensor): 要连接的张量序列。所有张量必须在除连接维度外的其他维度上形状相同。dim(int, 可选): 连接的维度,默认为 0。out(Tensor, 可选): 输出张量。
返回值:
torch.Tensor: 返回连接后的张量。
使用示例
实例
import torch
# 创建两个张量
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 连接两个张量
result = torch.concat([a, b])
print(result)
# 创建两个张量
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# 连接两个张量
result = torch.concat([a, b])
print(result)
输出结果为:
tensor([1, 2, 3, 4, 5, 6])
实例
import torch
# 创建两个二维张量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
# 沿第一维度连接
c = torch.concat([a, b], dim=0)
# 沿第二维度连接
d = torch.concat([a, b], dim=1)
print("沿 dim=0 连接:")
print(c)
print("n沿 dim=1 连接:")
print(d)
# 创建两个二维张量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
# 沿第一维度连接
c = torch.concat([a, b], dim=0)
# 沿第二维度连接
d = torch.concat([a, b], dim=1)
print("沿 dim=0 连接:")
print(c)
print("n沿 dim=1 连接:")
print(d)
输出结果为:
沿 dim=0 连接:
tensor([[1, 2],
[3, 4],
[5, 6],
[7, 8]])
沿 dim=1 连接:
tensor([[1, 2, 5, 6],
[3, 4, 7, 8]])
注意:torch.concat 是 torch.cat 的别名,两者的功能完全相同。

Pytorch torch 参考手册