PyTorch torch.cat 函数
torch.cat 是 PyTorch 中用于沿指定维度连接多个张量的函数。它会将多个张量沿指定维度拼接成一个更大的张量。
这是深度学习中非常常用的操作,例如在连接特征图、拼接数据批次等场景。
函数定义
torch.cat(tensors, dim=0, out=None)
参数:
tensors(Sequence of Tensor): 要连接的张量序列。所有张量必须在除连接维度外的其他维度上形状相同。dim(int, 可选): 连接的维度,默认为 0。out(Tensor, 可选): 输出张量。
返回值:
torch.Tensor: 返回连接后的张量。
使用示例
示例 1: 沿第一维度连接
实例
import torch
# 创建两个张量
a = torch.randn(2, 3)
b = torch.randn(2, 3)
# 沿第一维度(行)连接
c = torch.cat([a, b], dim=0)
print("a 的形状:", a.shape)
print("b 的形状:", b.shape)
print("c 的形状:", c.shape)
print(c)
# 创建两个张量
a = torch.randn(2, 3)
b = torch.randn(2, 3)
# 沿第一维度(行)连接
c = torch.cat([a, b], dim=0)
print("a 的形状:", a.shape)
print("b 的形状:", b.shape)
print("c 的形状:", c.shape)
print(c)
输出结果为:
a 的形状: torch.Size([2, 3])
b 的形状: torch.Size([2, 3])
c 的形状: torch.Size([4, 3])
tensor([[ 0.2532, 0.3643, 0.5341],
[ 0.9578, 0.9086, -0.2847],
[-0.7108, -0.0142, 0.7168],
[-0.1542, -0.9841, -1.4945]])
示例 2: 沿第二维度连接
实例
import torch
# 创建两个张量
a = torch.randn(2, 3)
b = torch.randn(2, 4)
# 沿第二维度(列)连接
c = torch.cat([a, b], dim=1)
print("a 的形状:", a.shape)
print("b 的形状:", b.shape)
print("c 的形状:", c.shape)
# 创建两个张量
a = torch.randn(2, 3)
b = torch.randn(2, 4)
# 沿第二维度(列)连接
c = torch.cat([a, b], dim=1)
print("a 的形状:", a.shape)
print("b 的形状:", b.shape)
print("c 的形状:", c.shape)
输出结果为:
a 的形状: torch.Size([2, 3]) b 的形状: torch.Size([2, 4]) c 的形状: torch.Size([2, 7])
示例 3: 连接多个张量
实例
import torch
# 创建多个张量
a = torch.tensor([1, 2])
b = torch.tensor([3, 4])
c = torch.tensor([5, 6])
# 连接多个张量
result = torch.cat([a, b, c])
print(result)
# 创建多个张量
a = torch.tensor([1, 2])
b = torch.tensor([3, 4])
c = torch.tensor([5, 6])
# 连接多个张量
result = torch.cat([a, b, c])
print(result)
输出结果为:
tensor([1, 2, 3, 4, 5, 6]) </p> <h3>示例 4: 在神经网络中拼接特征</h3> <div class="example"> <h2 class="example">实例</h2> <div class="example_code"> import torch
# 模拟来自不同层的特征图
feature1 = torch.randn(1, 64, 32, 32) # 来自第一层的特征
feature2 = torch.randn(1, 128, 32, 32) # 来自第二层的特征
# 在通道维度(dim=1)拼接特征
combined = torch.cat([feature1, feature2], dim=1)
print("特征1 形状:", feature1.shape)
print("特征2 形状:", feature2.shape)
print("拼接后形状:", combined.shape)
</div> </div> <p>输出结果为:</p> <pre> 特征1 形状: torch.Size([1, 64, 32, 32]) 特征2 形状: torch.Size([1, 128, 32, 32]) 拼接后形状: torch.Size([1, 192, 32, 32])
在神经网络中,torch.cat 常用于特征金字塔(FPN)等结构中融合不同层的特征。
torch.cat 与 torch.stack 的区别
torch.cat: 沿现有维度连接,张量形状会在连接维度上相加。torch.stack: 沿新维度堆叠,会增加一个新的维度。
实例
import torch
a = torch.randn(2, 3)
b = torch.randn(2, 3)
# cat 和 stack 的区别
cat_result = torch.cat([a, b], dim=0)
stack_result = torch.stack([a, b], dim=0)
print("cat 结果形状:", cat_result.shape) # (4, 3)
print("stack 结果形状:", stack_result.shape) # (2, 2, 3)
a = torch.randn(2, 3)
b = torch.randn(2, 3)
# cat 和 stack 的区别
cat_result = torch.cat([a, b], dim=0)
stack_result = torch.stack([a, b], dim=0)
print("cat 结果形状:", cat_result.shape) # (4, 3)
print("stack 结果形状:", stack_result.shape) # (2, 2, 3)

Pytorch torch 参考手册