现在位置: 首页 > PyTorch 教程 > 正文

PyTorch 生成对抗网络 (GAN)

生成对抗网络(Generative Adversarial Network,GAN)是深度学习中最具创意的模型架构之一。它通过让两个神经网络相互对抗、相互学习,最终能够生成非常逼真的数据。GAN 广泛应用于图像生成、风格迁移、数据增强等场景。


1. GAN 核心原理

GAN 的核心思想来源于博弈论中的"零和博弈"。它包含两个相互对抗的网络:

  • 生成器(Generator):学习生成假数据,目标是让判别器无法区分生成数据与真实数据
  • 判别器(Discriminator):学习区分真实数据与生成数据,目标是尽可能准确判断

两者在训练过程中相互对抗、不断提升,最终达到纳什均衡状态。

1.1 GAN 的目标函数

GAN 的训练目标可以表示为以下minimax游戏:

\[ \min_G \max_D \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] \]

其中:

  • \(G\) 表示生成器网络
  • \(D\) 表示判别器网络
  • \(x\) 表示真实数据
  • \(z\) 表示随机噪声向量(通常服从标准正态分布)
  • \(G(z)\) 表示生成器根据噪声生成的假数据

1.2 训练过程解读

GAN 的训练分为两个阶段:

第一阶段:训练判别器

固定生成器,提升判别器的分辨能力:

\[ \max_D \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))] \]

第二阶段:训练生成器

固定判别器,提升生成器的欺骗能力:

\[ \min_G \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))] \]

实际训练中,通常先训练判别器 k 步,再训练生成器 1 步,以保持平衡。


2. 基础 GAN 实现

下面实现一个最简单的 GAN——用于生成二维数据点。

2.1 定义生成器和判别器

实例

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# 设置随机种子
torch.manual_seed(42)

# ── 生成器网络 ──────────────────────────────────────
class Generator(nn.Module):
    """
    生成器:从随机噪声生成数据
    输入:噪声向量 (batch_size, noise_dim)
    输出:生成数据 (batch_size, data_dim)
    """

    def __init__(self, noise_dim, data_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, data_dim),
            # 输出不激活,GAN 会学习合适的分布
        )

    def forward(self, x):
        return self.net(x)


# ── 判别器网络 ──────────────────────────────────────
class Discriminator(nn.Module):
    """
    判别器:区分真实数据与生成数据
    输入:数据点 (batch_size, data_dim)
    输出:真实数据的概率 (batch_size, 1)
    """

    def __init__(self, data_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(data_dim, hidden_dim),
            nn.LeakyReLU(0.2),  # LeakyReLU 防止梯度消失
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()  # 输出概率
        )

    def forward(self, x):
        return self.net(x)


# 超参数
NOISE_DIM = 16
DATA_DIM = 2
HIDDEN_DIM = 64
BATCH_SIZE = 128

# 创建网络
generator = Generator(NOISE_DIM, DATA_DIM, HIDDEN_DIM)
discriminator = Discriminator(DATA_DIM, HIDDEN_DIM)

print(f"生成器参数量: {sum(p.numel() for p in generator.parameters()):,}")
print(f"判别器参数量: {sum(p.numel() for p in discriminator.parameters()):,}")

2.2 训练循环

实例

# ── 优化器 ──────────────────────────────────────
lr = 0.001
g_optimizer = optim.Adam(generator.parameters(), lr=lr)
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr)

# 损失函数:二分类交叉熵
criterion = nn.BCELoss()

# ── 训练数据:环形分布 ──────────────────────────
def generate_real_data(batch_size):
    """生成环形分布的真实数据"""
    angles = torch.rand(batch_size) * 2 * torch.pi
    radius = 1.0 + torch.randn(batch_size) * 0.1  # 半径约为 1
    x = radius * torch.cos(angles)
    y = radius * torch.sin(angles)
    return torch.stack([x, y], dim=1)


# ── 训练循环 ──────────────────────────────────────
NUM_EPOCHS = 1000
d_losses = []
g_losses = []

for epoch in range(NUM_EPOCHS):
    # 1. 训练判别器
    # 生成假数据
    noise = torch.randn(BATCH_SIZE, NOISE_DIM)
    fake_data = generator(noise).detach()  # detach 避免计算生成器梯度

    # 生成真实数据
    real_data = generate_real_data(BATCH_SIZE)

    # 判别器损失
    real_pred = discriminator(real_data)
    fake_pred = discriminator(fake_data)
    d_loss = criterion(real_pred, torch.ones_like(real_pred)) + \
             criterion(fake_pred, torch.zeros_like(fake_pred))

    # 更新判别器
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    # 2. 训练生成器
    # 生成新一批假数据
    noise = torch.randn(BATCH_SIZE, NOISE_DIM)
    fake_data = generator(noise)

    # 生成器损失:让判别器认为生成的数据是真实的
    fake_pred = discriminator(fake_data)
    g_loss = criterion(fake_pred, torch.ones_like(fake_pred))

    # 更新生成器
    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()

    # 记录损失
    d_losses.append(d_loss.item())
    g_losses.append(g_loss.item())

    if (epoch + 1) % 100 == 0:
        print(f"Epoch {epoch+1:4d} | D_loss: {d_loss:.4f} | G_loss: {g_loss:.4f}")

print("训练完成!")

2.3 可视化生成结果

实例

# 生成数据并可视化
def visualize_results(generator, num_samples=1000):
    noise = torch.randn(num_samples, NOISE_DIM)
    generated_data = generator(noise).detach().numpy()

    plt.figure(figsize=(6, 6))
    plt.scatter(generated_data[:, 0], generated_data[:, 1],
                alpha=0.5, s=10, c='blue', label='Generated')
    plt.xlim(-2, 2)
    plt.ylim(-2, 2)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('GAN Generated Data')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()


# 查看生成效果
visualize_results(generator)

3. DCGAN - 深度卷积 GAN

DCGAN 是将卷积神经网络引入 GAN 的经典架构,大幅提升了图像生成质量。

3.1 DCGAN 架构要点

  • 使用转置卷积(Transposed Convolution)进行上采样生成图像
  • 使用带步长的卷积进行下采样判别图像
  • 在生成器和判别器中使用 BatchNorm(但输出层和输入层不使用)
  • 生成器使用 ReLU,判别器使用 LeakyReLU

3.2 DCGAN 实现

实例

import torch
import torch.nn as nn

# ── DCGAN 生成器 ─────────────────────────────────
class DCGenerator(nn.Module):
    """
    DCGAN 生成器:使用转置卷积上采样
    """

    def __init__(self, noise_dim=100, channels=3, features_g=64):
        super().__init__()
        self.noise_dim = noise_dim

        # 输入: noise_dim x 1 x 1
        self.net = nn.Sequential(
            # 转置卷积: (batch, features_g*8, 4, 4)
            nn.ConvTranspose2d(noise_dim, features_g * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(features_g * 8),
            nn.ReLU(True),

            # (batch, features_g*4, 8, 8)
            nn.ConvTranspose2d(features_g * 8, features_g * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(True),

            # (batch, features_g*2, 16, 16)
            nn.ConvTranspose2d(features_g * 4, features_g * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(True),

            # (batch, features_g, 32, 32)
            nn.ConvTranspose2d(features_g * 2, features_g, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g),
            nn.ReLU(True),

            # 输出: (batch, channels, 64, 64)
            nn.ConvTranspose2d(features_g, channels, 4, 2, 1, bias=False),
            nn.Tanh()  # 输出范围 [-1, 1]
        )

    def forward(self, x):
        # x: (batch, noise_dim) -> (batch, noise_dim, 1, 1)
        x = x.view(x.size(0), x.size(1), 1, 1)
        return self.net(x)


# ── DCGAN 判别器 ─────────────────────────────────
class DCDiscriminator(nn.Module):
    """
    DCGAN 判别器:使用卷积下采样
    """

    def __init__(self, channels=3, features_d=64):
        super().__init__()

        # 输入: (batch, channels, 64, 64)
        self.net = nn.Sequential(
            # (batch, features_d, 32, 32)
            nn.Conv2d(channels, features_d, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # (batch, features_d*2, 16, 16)
            nn.Conv2d(features_d, features_d * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # (batch, features_d*4, 8, 8)
            nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 4),
            nn.LeakyReLU(0.2, inplace=True),

            # (batch, features_d*8, 4, 4)
            nn.Conv2d(features_d * 4, features_d * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 8),
            nn.LeakyReLU(0.2, inplace=True),

            # 输出: (batch, 1, 1, 1)
            nn.Conv2d(features_d * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x).view(x.size(0), -1)


# 测试网络
noise_dim = 100
generator = DCGenerator(noise_dim=noise_dim, channels=3, features_g=64)
discriminator = DCDiscriminator(channels=3, features_d=64)

# 测试前向传播
noise = torch.randn(2, noise_dim)
fake_images = generator(noise)
print(f"生成图像形状: {fake_images.shape}")  # torch.Size([2, 3, 64, 64])

decision = discriminator(fake_images)
print(f"判别结果形状: {decision.shape}")     # torch.Size([2, 1])

3.3 完整 DCGAN 训练代码

实例

# ── 训练配置 ─────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

NOISE_DIM = 100
LEARNING_RATE = 0.0002
BETA1 = 0.5  # Adam 参数

# 创建网络
generator = DCGenerator(noise_dim=NOISE_DIM).to(device)
discriminator = DCDiscriminator().to(device)

# 优化器
g_optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))

criterion = nn.BCELoss()

# ── 训练循环 ─────────────────────────────────────
fixed_noise = torch.randn(64, NOISE_DIM, device=device)  # 用于可视化

def train_dcgan(generator, discriminator, g_optimizer, d_optimizer, criterion,
                num_epochs, device, fixed_noise):
    G_losses = []
    D_losses = []

    for epoch in range(num_epochs):
        for batch_idx in range(100):  # 假设每个 epoch 有 100 个 batch
            # 训练判别器
            discriminator.zero_grad()

            # 真实图像(假设已有)
            # real_images = ...
            # 这里用随机噪声模拟
            real_images = torch.randn(32, 3, 64, 64).to(device)

            batch_size = real_images.size(0)
            labels_real = torch.ones(batch_size, 1).to(device)
            labels_fake = torch.zeros(batch_size, 1).to(device)

            # 真实图像损失
            output = discriminator(real_images)
            d_loss_real = criterion(output, labels_real)

            # 生成图像损失
            noise = torch.randn(batch_size, NOISE_DIM).to(device)
            fake_images = generator(noise)
            output = discriminator(fake_images.detach())
            d_loss_fake = criterion(output, labels_fake)

            # 总损失
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            d_optimizer.step()

            # 训练生成器
            generator.zero_grad()

            noise = torch.randn(batch_size, NOISE_DIM).to(device)
            fake_images = generator(noise)
            output = discriminator(fake_images)
            g_loss = criterion(output, labels_real)  # 希望生成图像被判定为真

            g_loss.backward()
            g_optimizer.step()

            # 记录损失
            if batch_idx % 50 == 0:
                G_losses.append(g_loss.item())
                D_losses.append(d_loss.item())
                print(f"[{epoch}/{num_epochs}][{batch_idx}/100] "
                      f"D_loss: {d_loss:.4f} | G_loss: {g_loss:.4f}")

    return G_losses, D_losses


# 开始训练
# G_losses, D_losses = train_dcgan(generator, discriminator, g_optimizer,
#                                   d_optimizer, criterion, 5, device, fixed_noise)
print("DCGAN 架构已定义完成,可以开始训练!")

4. GAN 的训练技巧

4.1 常见问题与解决方案

问题 原因 解决方案
模式崩溃(Mode Collapse) 生成器只生成有限的几种样本 使用 WGAN、添加小批量判别、使用标签平滑
判别器过强 生成器梯度消失,无法学习 训练生成器多次、降低判别器学习率、使用 LeakyReLU
训练不稳定 GAN 目标函数非凸,难以收敛 使用谱归一化、梯度惩罚、learning rate warmup
生成质量差 网络容量不足或训练不足 增加网络深度、使用更多训练数据、训练更长时间

4.2 损失函数改进

原始 GAN 使用 JS 散度,存在梯度消失问题。WGAN 使用 Wasserstein 距离更加稳定:

实例

# WGAN 损失函数(替代 BCE)
def wgan_d_loss(real_pred, fake_pred):
    """判别器损失:真实样本得分高,生成样本得分低"""
    return -(torch.mean(real_pred) - torch.mean(fake_pred))

def wgan_g_loss(fake_pred):
    """生成器损失:让生成样本得分高"""
    return -torch.mean(fake_pred)

# 梯度惩罚(Gradient Penalty)- WGAN-GP
def gradient_penalty(discriminator, real_images, fake_images, device):
    """WGAN-GP 梯度惩罚项"""
    batch_size = real_images.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1).to(device)

    # 在真实和生成图像之间插值
    interpolated = alpha * real_images + (1 - alpha) * fake_images
    interpolated.requires_grad = True

    # 计算插值图像的判别器输出
    pred = discriminator(interpolated)

    # 计算梯度
    gradients = torch.autograd.grad(
        outputs=pred,
        inputs=interpolated,
        grad_outputs=torch.ones_like(pred),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]

    # 计算梯度范数
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    penalty = ((gradient_norm - 1) ** 2).mean()

    return penalty

4.3 谱归一化(Spectral Normalization)

谱归一化可以稳定 GAN 训练,控制判别器的 Lipschitz 常数:

实例

import torch.nn.utils.spectral_norm as spectral_norm

# 使用谱归一化的判别器
class SNDiscriminator(nn.Module):
    def __init__(self, channels=3, features_d=64):
        super().__init__()
        self.net = nn.Sequential(
            spectral_norm(nn.Conv2d(channels, features_d, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),

            spectral_norm(nn.Conv2d(features_d, features_d * 2, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),

            spectral_norm(nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),

            spectral_norm(nn.Conv2d(features_d * 4, 1, 4, 1, 0)),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x).view(x.size(0), -1)

5. 条件 GAN (cGAN)

条件 GAN 允许指定生成数据的类别标签,实现有条件的生成。

5.1 cGAN 架构

实例

import torch
import torch.nn as nn

class ConditionalGenerator(nn.Module):
    """条件生成器:同时接收噪声和类别标签"""
    def __init__(self, noise_dim, num_classes, embed_dim, img_channels, features_g=64):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, embed_dim)

        # 将噪声和标签嵌入拼接
        self.net = nn.Sequential(
            nn.Linear(noise_dim + embed_dim, features_g * 8 * 4 * 4),
            nn.BatchNorm1d(features_g * 8 * 4 * 4),
            nn.ReLU(),
            # 然后接转置卷积(类似 DCGAN)
            # ...
        )

    def forward(self, noise, labels):
        # 将类别标签嵌入到与噪声相同的维度
        label_embedding = self.label_emb(labels)
        # 拼接噪声和标签嵌入
        x = torch.cat([noise, label_embedding], dim=1)
        return self.net(x)


class ConditionalDiscriminator(nn.Module):
    """条件判别器:同时接收图像和类别标签"""
    def __init__(self, img_channels, num_classes, embed_dim, features_d=64):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, embed_dim)

        # 将图像和标签嵌入拼接
        self.net = nn.Sequential(
            nn.Conv2d(img_channels + embed_dim, features_d, 4, 2, 1),
            nn.LeakyReLU(0.2),
            # ...
        )

    def forward(self, img, labels):
        # 将标签嵌入调整到与图像相同的空间尺寸
        label_embedding = self.label_emb(labels)
        # 调整形状以便拼接
        label_embedding = label_embedding.unsqueeze(2).unsqueeze(3)
        label_embedding = label_embedding.expand(-1, -1, img.size(2), img.size(3))
        # 拼接图像和标签
        x = torch.cat([img, label_embedding], dim=1)
        return self.net(x)

6. GAN 评估指标

6.1 常用评估指标

指标 描述 优点 缺点
Inception Score (IS) 使用 Inception v3 评估生成图像的质量和多样性 计算简单,与人类判断有一定相关性 不评估过拟合,无法检测模式崩溃
Fréchet Inception Distance (FID) 计算真实图像和生成图像在特征空间的距离 对噪声和模式崩溃更敏感 需要大量样本,计算较慢
人工评估 人工判断生成图像质量 最准确反映人类感知 主观、耗时

6.2 FID 计算实现

实例

import numpy as np
from scipy import linalg

def calculate_fid(real_activations, fake_activations):
    """
    计算 Fréchet Inception Distance
    real_activations: 真实图像的特征向量 (N, dim)
    fake_activations: 生成图像的特征向量 (N, dim)
    """

    # 计算均值和协方差
    mu1, sigma1 = real_activations.mean(axis=0), np.cov(real_activations, rowvar=False)
    mu2, sigma2 = fake_activations.mean(axis=0), np.cov(fake_activations, rowvar=False)

    # 计算 FID
    diff = mu1 - mu2
    # 计算协方差矩阵的和
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)

    # 避免复数
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)
    return fid


# 简化示例:使用随机数据
np.random.seed(42)
real_acts = np.random.randn(1000, 2048)  # Inception v3 输出维度
fake_acts = np.random.randn(1000, 2048)

fid_score = calculate_fid(real_acts, fake_acts)
print(f"FID Score: {fid_score:.2f}")
# FID 越低越好,通常小于 50 表示较好的生成质量

7. 常见 GAN 变体

GAN 发展至今产生了众多变体,适用于不同的应用场景:

模型 全称 特点 适用场景
DCGAN Deep Convolutional GAN 使用卷积网络,生成高质量图像 图像生成
WGAN Wasserstein GAN 使用 Wasserstein 距离,训练更稳定 稳定训练
WGAN-GP WGAN with Gradient Penalty 梯度惩罚替代权重裁剪 稳定训练
CGAN Conditional GAN 加入条件信息,可控生成 条件生成
CycleGAN Cycle-Consistent GAN 无监督图像转换 风格迁移
StyleGAN Style-Based GAN 风格控制,高质量人脸生成 人脸生成
BigGAN Big GAN 大规模、高质量图像生成 高分辨率图像
ProGAN Progressive Growing GAN 渐进式增大分辨率 高分辨率生成