
PyTorch torch.nn.Conv1d Module

PyTorch torch.nn Reference Manual


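torch.nn.Conv1d is PyTorch's one-dimensional (1D) convolution module.

It is mainly used for sequence data such as text, audio, and time-series signals, where a kernel slides along a single (length) dimension.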
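At each position along the length, a Conv1d layer multiplies the kernel elementwise with the input window, sums over all input channels and kernel taps, and adds the bias. Like PyTorch's other convolution layers, this is technically a cross-correlation (the kernel is not flipped). The sketch below reproduces a layer's output by hand with Tensor.unfold and torch.einsum. The sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary small layer: 2 input channels, 3 output channels, kernel size 4
conv = nn.Conv1d(in_channels=2, out_channels=3, kernel_size=4)
x = torch.randn(1, 2, 10)  # (batch, channels, length)

with torch.no_grad():
    # Sliding windows along the length: (batch, in_channels, L_out, kernel_size)
    windows = x.unfold(2, 4, 1)
    # weight is (out_channels, in_channels, kernel_size):
    # sum over input channels (c) and kernel taps (k), then add the bias per output channel
    manual = torch.einsum("bclk,ock->bol", windows, conv.weight) + conv.bias.view(1, -1, 1)
    ref = conv(x)

print(ref.shape)                               # torch.Size([1, 3, 7])
print(torch.allclose(manual, ref, atol=1e-5))  # True
```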

Definition

torch.nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)
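The constructor arguments determine the shapes of the learnable parameters. PyTorch stores weight with shape (out_channels, in_channels / groups, kernel_size) and bias with shape (out_channels,). The minimal sketch below uses channel counts chosen only for illustration.

```python
import torch.nn as nn

# Standard convolution: every output channel sees all 4 input channels
conv = nn.Conv1d(in_channels=4, out_channels=8, kernel_size=3)
print(conv.weight.shape)  # torch.Size([8, 4, 3])
print(conv.bias.shape)    # torch.Size([8])

# groups=2: inputs split into 2 groups of 2 channels, each group feeds 4 of the outputs
grouped = nn.Conv1d(in_channels=4, out_channels=8, kernel_size=3, groups=2)
print(grouped.weight.shape)  # torch.Size([8, 2, 3])

# groups=in_channels gives a depthwise convolution; bias=False registers no bias
depthwise = nn.Conv1d(in_channels=4, out_channels=4, kernel_size=3, groups=4, bias=False)
print(depthwise.weight.shape)  # torch.Size([4, 1, 3])
print(depthwise.bias)          # None
```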

Parameters

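  • in_channels (int): number of channels in the input
  • out_channels (int): number of channels produced by the convolution
  • kernel_size (int or tuple): size of the convolution kernel
  • stride (int or tuple, default 1): step size of the kernel along the sequence
  • padding (int, tuple, or str, default 0): padding added to both ends of the input; the strings "valid" and "same" are also accepted
  • dilation (int or tuple, default 1): spacing between kernel elements
  • groups (int, default 1): number of blocked connections from input channels to output channels; in_channels and out_channels must both be divisible by groups
  • bias (bool, default True): if True, adds a learnable bias to the output
  • padding_mode (str, default "zeros"): one of "zeros", "reflect", "replicate", or "circular"
  • device, dtype: device and data type used when creating the layer's parameters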
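The output length follows the formula in the PyTorch documentation: L_out = floor((L_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1). The sketch below implements it as a helper, `conv1d_out_len`, a hypothetical name that is not a PyTorch API. It handles integer padding only and checks the result against real layers for a few configurations.

```python
import torch
import torch.nn as nn


def conv1d_out_len(l_in, kernel_size, stride=1, padding=0, dilation=1):
    """Conv1d output length for integer padding (formula from the PyTorch docs)."""
    # floor(a / s + 1) == a // s + 1 for integers a and s > 0
    return (l_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


configs = [
    dict(kernel_size=3, padding=1),              # as in Example 1: length preserved
    dict(kernel_size=5, stride=2),               # as in Example 3: downsampled
    dict(kernel_size=3, padding=2, dilation=2),  # dilated kernel spans 5 samples
]

x = torch.randn(1, 1, 100)
for cfg in configs:
    conv = nn.Conv1d(1, 1, **cfg)
    actual = conv(x).shape[-1]
    print(cfg, "actual:", actual, "formula:", conv1d_out_len(100, **cfg))
```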

Usage Examples

Example 1: Basic Usage


import torch
import torch.nn as nn

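# 1D convolution: in_channels=3, out_channels=64, kernel_size=3; padding=1 keeps the length unchanged
conv1d = nn.Conv1d(in_channels=3, out_channels=64, kernel_size=3, padding=1)

# Input: batch=4, channels=3, sequence length=100
x = torch.randn(4, 3, 100)

output = conv1d(x)

print("Input shape:", x.shape)        # Input shape: torch.Size([4, 3, 100])
print("Output shape:", output.shape)  # Output shape: torch.Size([4, 64, 100])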
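PyTorch 1.9 and later also accept the strings "same" and "valid" for padding, so you do not have to work out the padding yourself. The minimal sketch below shows both. "same" requires stride=1, and kernel_size=5 is an arbitrary choice here.

```python
import torch
import torch.nn as nn

x = torch.randn(4, 3, 100)  # (batch, channels, length), as in Example 1

# padding="same": pads so the output length equals the input length (stride must be 1)
conv_same = nn.Conv1d(3, 64, kernel_size=5, padding="same")
# padding="valid": no padding at all, equivalent to padding=0
conv_valid = nn.Conv1d(3, 64, kernel_size=5, padding="valid")

print(conv_same(x).shape)   # torch.Size([4, 64, 100])
print(conv_valid(x).shape)  # torch.Size([4, 64, 96])
```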

Example 2: Text Classification


import torch
import torch.nn as nn

class TextCNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_classes):
        super().__init__()
        # Embedding layer: (batch, seq) -> (batch, seq, embed)
        self.embedding = nn.Embedding(vocab_size, embed_dim)

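        # Several convolutions with different kernel sizes, 128 filters each
        self.convs = nn.ModuleList([
            nn.Conv1d(embed_dim, 128, kernel_size=k)
            for k in [2, 3, 4, 5]
        ])
        # 4 kernel sizes x 128 channels are concatenated before the classifier
        self.fc = nn.Linear(128 * 4, num_classes)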

    def forward(self, x):
        # x: (batch, seq)
        x = self.embedding(x)  # (batch, seq, embed)
        x = x.permute(0, 2, 1)  # (batch, embed, seq)

        # Convolution + ReLU + global max pooling for each kernel size
        pooled = []
        for conv in self.convs:
            c = conv(x)  # (batch, 128, seq - k + 1)
            c = nn.functional.relu(c)
            c = c.max(dim=2).values  # global max pooling over the length: (batch, 128)
            pooled.append(c)

        x = torch.cat(pooled, dim=1)  # (batch, 128 * 4)
        return self.fc(x)

# Test
model = TextCNN(vocab_size=10000, embed_dim=128, num_classes=2)
x = torch.randint(0, 10000, (4, 50))  # batch=4, seq=50
output = model(x)

print("Input shape:", x.shape)        # Input shape: torch.Size([4, 50])
print("Output shape:", output.shape)  # Output shape: torch.Size([4, 2])
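TextCNN.forward computes global max pooling with c.max(dim=2). The nn.AdaptiveMaxPool1d module can express the same reduction. The sketch below uses a random tensor shaped like the output of the kernel_size=5 branch for seq=50 and checks that both forms agree. Because these convolutions use no padding, each input sequence must be at least as long as the largest kernel (5); a shorter sequence makes the convolution raise an error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
c = torch.randn(4, 128, 46)  # (batch, 128, seq - k + 1) with seq=50, k=5

# Global max pooling as written in TextCNN.forward
manual = c.max(dim=2).values            # (4, 128)

# The same reduction as a module: pool each channel down to length 1, then drop that axis
pool = nn.AdaptiveMaxPool1d(output_size=1)
via_module = pool(c).squeeze(-1)        # (4, 128, 1) -> (4, 128)

print(torch.equal(manual, via_module))  # True
```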

Example 3: Audio Processing


import torch
import torch.nn as nn

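# Audio feature extraction: kernel_size=5 with stride=2 roughly halves the length
conv_audio = nn.Conv1d(in_channels=1, out_channels=32, kernel_size=5, stride=2)

# Mono audio: batch=4, channels=1, samples=16000
audio = torch.randn(4, 1, 16000)

output = conv_audio(audio)

print("Input shape:", audio.shape)        # Input shape: torch.Size([4, 1, 16000])
print("Output shape:", output.shape)      # Output shape: torch.Size([4, 32, 7998])
print("Output length:", output.shape[2])  # Output length: 7998, i.e. (16000 - 5) // 2 + 1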
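Audio front ends often stack several strided convolutions. The two-layer encoder below is a hypothetical configuration: its channel counts and kernel sizes are illustrative and not taken from this tutorial. Applying the output length formula once per layer gives 16000 -> 7998 -> 3997.

```python
import torch
import torch.nn as nn

# Hypothetical encoder: each strided Conv1d roughly halves the number of time steps
encoder = nn.Sequential(
    nn.Conv1d(1, 32, kernel_size=5, stride=2),   # 16000 -> 7998
    nn.ReLU(),
    nn.Conv1d(32, 64, kernel_size=5, stride=2),  # 7998 -> 3997
    nn.ReLU(),
)

audio = torch.randn(4, 1, 16000)  # (batch, channels, samples)
features = encoder(audio)
print(features.shape)  # torch.Size([4, 64, 3997])
```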

Use Cases

  • Text classification: TextCNN
  • Audio processing: speech feature extraction
  • Time series: signal filtering
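For signal filtering, a Conv1d with fixed weights behaves like a finite impulse response (FIR) filter. The minimal sketch below assumes a 3-tap moving average, an arbitrary kernel length. Every weight is set to 1/3, the bias is disabled, and the weights are frozen so training does not change them.

```python
import torch
import torch.nn as nn

k = 3  # number of taps (assumed for illustration)
smoother = nn.Conv1d(1, 1, kernel_size=k, bias=False)
with torch.no_grad():
    smoother.weight.fill_(1.0 / k)  # every tap = 1/k -> moving average
smoother.weight.requires_grad_(False)  # fixed filter, not learned

signal = torch.tensor([[[1.0, 2.0, 3.0, 4.0, 5.0]]])  # (batch=1, channels=1, length=5)
smoothed = smoother(signal)  # (1, 1, 3): means of each 3-sample window, about 2, 3, 4
print(smoothed)
```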

Note: Conv1d expects input of shape (batch, channels, length) and returns output of shape (batch, out_channels, L_out).
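Time-series data is often stored as (batch, length, features). In that case, swap the length and channel dimensions before calling the layer, as TextCNN does with permute. Without the swap, the layer would read the 200 time steps as channels and raise an error. The sketch below uses hypothetical sizes: 8 series, 200 steps, 6 features.

```python
import torch
import torch.nn as nn

# Hypothetical multivariate series in (batch, length, features) layout
series = torch.randn(8, 200, 6)

conv = nn.Conv1d(in_channels=6, out_channels=16, kernel_size=3, padding=1)

# Conv1d needs (batch, channels, length): swap dims 1 and 2 before the layer
out = conv(series.transpose(1, 2))  # (8, 16, 200)
out = out.transpose(1, 2)           # back to (batch, length, channels): (8, 200, 16)
print(out.shape)  # torch.Size([8, 200, 16])
```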

