PyTorch Generative Adversarial Networks (GANs)
A Generative Adversarial Network (GAN) is one of the most inventive model architectures in deep learning. By pitting two neural networks against each other so that each learns from the other, a GAN can eventually produce remarkably realistic data. GANs are widely used for image generation, style transfer, data augmentation, and more.
1. Core Principles of GANs
The idea behind GANs comes from the "zero-sum game" of game theory. A GAN consists of two adversarial networks:
- Generator: learns to produce fake data, aiming to make the discriminator unable to tell generated data from real data
- Discriminator: learns to tell real data from generated data, aiming to judge as accurately as possible
The two networks compete and improve together during training, ideally converging to a Nash equilibrium.
1.1 The GAN Objective
GAN training can be expressed as the following minimax game:
\[ \min_G \max_D \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] \]
where:
- \(G\) is the generator network
- \(D\) is the discriminator network
- \(x\) is real data
- \(z\) is a random noise vector (usually drawn from a standard normal distribution)
- \(G(z)\) is the fake data the generator produces from the noise
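For a fixed generator, the inner maximization has a closed-form solution (writing \(p_g\) for the distribution of \(G(z)\)), which shows that the minimax game is implicitly minimizing a Jensen-Shannon divergence between the real and generated distributions:

```latex
% Optimal discriminator for a fixed generator G:
D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}
% Substituting D^* back into the objective, the generator's problem becomes
\min_G \; 2 \cdot \mathrm{JSD}\!\left(p_{data} \,\|\, p_g\right) - \log 4
```

This is why the minimum of the game is attained exactly when \(p_g = p_{data}\), and why the JS divergence is mentioned again in the discussion of WGAN below.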
1.2 The Training Procedure
GAN training alternates between two phases:
Phase 1: train the discriminator
With the generator fixed, improve the discriminator's ability to tell real from fake:
\[ \max_D \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))] \]
Phase 2: train the generator
With the discriminator fixed, improve the generator's ability to fool it:
\[ \min_G \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))] \]
In practice, the discriminator is often trained for k steps for every generator step to keep the two in balance. The generator is also usually trained with the non-saturating objective \(\max_G \mathbb{E}_{z \sim p_z}[\log D(G(z))]\) instead, which has the same fixed point but provides much stronger gradients early in training, when the discriminator easily rejects the fakes.
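The k-steps-of-D-per-step-of-G schedule can be sketched with toy networks (the tiny `Linear` stand-ins below are purely illustrative, not a real GAN):

```python
import torch
import torch.nn as nn

# Toy stand-ins for the two networks
G = nn.Linear(4, 2)                                # generator: noise -> data
D = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())   # discriminator
opt_d = torch.optim.SGD(D.parameters(), lr=0.01)
opt_g = torch.optim.SGD(G.parameters(), lr=0.01)
bce = nn.BCELoss()
k = 3  # discriminator steps per generator step

d_steps = g_steps = 0
for step in range(12):
    if step % (k + 1) < k:
        # Discriminator update: real -> 1, fake -> 0
        real = torch.randn(8, 2)
        fake = G(torch.randn(8, 4)).detach()       # detach: no G gradients here
        loss_d = bce(D(real), torch.ones(8, 1)) + bce(D(fake), torch.zeros(8, 1))
        opt_d.zero_grad(); loss_d.backward(); opt_d.step()
        d_steps += 1
    else:
        # Generator update: try to make D label the fakes as real
        fake = G(torch.randn(8, 4))
        loss_g = bce(D(fake), torch.ones(8, 1))
        opt_g.zero_grad(); loss_g.backward(); opt_g.step()
        g_steps += 1

print(d_steps, g_steps)  # 9 3
```

With k = 3 and 12 total steps, the discriminator receives 9 updates and the generator 3, i.e. exactly the 3:1 ratio. The basic implementation in the next section uses the simpler 1:1 schedule.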
2. A Basic GAN Implementation
Below we implement the simplest possible GAN, one that generates two-dimensional data points.
2.1 Defining the Generator and Discriminator
Example
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Set random seed for reproducibility
torch.manual_seed(42)

# ── Generator network ──────────────────────────────────────
class Generator(nn.Module):
    """
    Generator: produces data from random noise
    Input:  noise vector (batch_size, noise_dim)
    Output: generated data (batch_size, data_dim)
    """
    def __init__(self, noise_dim, data_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, data_dim),
            # No output activation; the GAN learns the appropriate range
        )

    def forward(self, x):
        return self.net(x)

# ── Discriminator network ──────────────────────────────────────
class Discriminator(nn.Module):
    """
    Discriminator: distinguishes real data from generated data
    Input:  data points (batch_size, data_dim)
    Output: probability of being real (batch_size, 1)
    """
    def __init__(self, data_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(data_dim, hidden_dim),
            nn.LeakyReLU(0.2),  # LeakyReLU mitigates vanishing gradients
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()  # output a probability
        )

    def forward(self, x):
        return self.net(x)

# Hyperparameters
NOISE_DIM = 16
DATA_DIM = 2
HIDDEN_DIM = 64
BATCH_SIZE = 128

# Create the networks
generator = Generator(NOISE_DIM, DATA_DIM, HIDDEN_DIM)
discriminator = Discriminator(DATA_DIM, HIDDEN_DIM)
print(f"Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}")
2.2 Training Loop
Example
lr = 0.001
g_optimizer = optim.Adam(generator.parameters(), lr=lr)
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr)

# Loss function: binary cross-entropy
criterion = nn.BCELoss()

# ── Training data: a ring-shaped distribution ──────────────────────────
def generate_real_data(batch_size):
    """Sample real data from a noisy ring of radius ~1"""
    angles = torch.rand(batch_size) * 2 * torch.pi
    radius = 1.0 + torch.randn(batch_size) * 0.1  # radius ≈ 1
    x = radius * torch.cos(angles)
    y = radius * torch.sin(angles)
    return torch.stack([x, y], dim=1)

# ── Training loop ──────────────────────────────────────
NUM_EPOCHS = 1000
d_losses = []
g_losses = []

for epoch in range(NUM_EPOCHS):
    # 1. Train the discriminator
    # Generate fake data
    noise = torch.randn(BATCH_SIZE, NOISE_DIM)
    fake_data = generator(noise).detach()  # detach so no generator gradients are computed
    # Sample real data
    real_data = generate_real_data(BATCH_SIZE)
    # Discriminator loss
    real_pred = discriminator(real_data)
    fake_pred = discriminator(fake_data)
    d_loss = criterion(real_pred, torch.ones_like(real_pred)) + \
             criterion(fake_pred, torch.zeros_like(fake_pred))
    # Update the discriminator
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    # 2. Train the generator
    # Generate a fresh batch of fake data
    noise = torch.randn(BATCH_SIZE, NOISE_DIM)
    fake_data = generator(noise)
    # Generator loss: make the discriminator believe the fakes are real
    fake_pred = discriminator(fake_data)
    g_loss = criterion(fake_pred, torch.ones_like(fake_pred))
    # Update the generator
    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()

    # Record losses
    d_losses.append(d_loss.item())
    g_losses.append(g_loss.item())

    if (epoch + 1) % 100 == 0:
        print(f"Epoch {epoch+1:4d} | D_loss: {d_loss.item():.4f} | G_loss: {g_loss.item():.4f}")

print("Training complete!")
2.3 Visualizing the Generated Samples
Example
def visualize_results(generator, num_samples=1000):
    noise = torch.randn(num_samples, NOISE_DIM)
    generated_data = generator(noise).detach().numpy()

    plt.figure(figsize=(6, 6))
    plt.scatter(generated_data[:, 0], generated_data[:, 1],
                alpha=0.5, s=10, c='blue', label='Generated')
    plt.xlim(-2, 2)
    plt.ylim(-2, 2)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('GAN Generated Data')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

# Inspect the generated samples
visualize_results(generator)
3. DCGAN: Deep Convolutional GAN
DCGAN is the classic architecture that brought convolutional networks into GANs, greatly improving the quality of generated images.
3.1 Key Points of the DCGAN Architecture
- Use transposed convolutions for upsampling when generating images
- Use strided convolutions for downsampling when discriminating images
- Use BatchNorm in both the generator and the discriminator (but not in the generator's output layer or the discriminator's input layer)
- Use ReLU in the generator and LeakyReLU in the discriminator
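The DCGAN paper additionally initializes all weights from a zero-centered normal distribution with standard deviation 0.02. A commonly used helper for this (not part of the architecture code below) looks like:

```python
import torch
import torch.nn as nn

def weights_init(m):
    """DCGAN-style init: N(0, 0.02) for conv weights, N(1, 0.02) for BatchNorm."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)

# Apply recursively to every submodule of a small example network
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
net.apply(weights_init)
print(net[1].bias.abs().sum().item())  # 0.0: BatchNorm biases are zeroed
```

`Module.apply` visits every submodule, so the same helper works unchanged on the full `DCGenerator` and `DCDiscriminator` defined below.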
3.2 DCGAN Implementation
Example
import torch
import torch.nn as nn

# ── DCGAN generator ─────────────────────────────────
class DCGenerator(nn.Module):
    """
    DCGAN generator: upsamples with transposed convolutions
    """
    def __init__(self, noise_dim=100, channels=3, features_g=64):
        super().__init__()
        self.noise_dim = noise_dim
        # Input: (batch, noise_dim, 1, 1)
        self.net = nn.Sequential(
            # -> (batch, features_g*8, 4, 4)
            nn.ConvTranspose2d(noise_dim, features_g * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(features_g * 8),
            nn.ReLU(True),
            # -> (batch, features_g*4, 8, 8)
            nn.ConvTranspose2d(features_g * 8, features_g * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(True),
            # -> (batch, features_g*2, 16, 16)
            nn.ConvTranspose2d(features_g * 4, features_g * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(True),
            # -> (batch, features_g, 32, 32)
            nn.ConvTranspose2d(features_g * 2, features_g, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g),
            nn.ReLU(True),
            # Output: (batch, channels, 64, 64)
            nn.ConvTranspose2d(features_g, channels, 4, 2, 1, bias=False),
            nn.Tanh()  # output range [-1, 1]
        )

    def forward(self, x):
        # x: (batch, noise_dim) -> (batch, noise_dim, 1, 1)
        x = x.view(x.size(0), x.size(1), 1, 1)
        return self.net(x)

# ── DCGAN discriminator ─────────────────────────────────
class DCDiscriminator(nn.Module):
    """
    DCGAN discriminator: downsamples with strided convolutions
    """
    def __init__(self, channels=3, features_d=64):
        super().__init__()
        # Input: (batch, channels, 64, 64)
        self.net = nn.Sequential(
            # -> (batch, features_d, 32, 32)
            nn.Conv2d(channels, features_d, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (batch, features_d*2, 16, 16)
            nn.Conv2d(features_d, features_d * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (batch, features_d*4, 8, 8)
            nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (batch, features_d*8, 4, 4)
            nn.Conv2d(features_d * 4, features_d * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # Output: (batch, 1, 1, 1)
            nn.Conv2d(features_d * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x).view(x.size(0), -1)

# Build the networks
noise_dim = 100
generator = DCGenerator(noise_dim=noise_dim, channels=3, features_g=64)
discriminator = DCDiscriminator(channels=3, features_d=64)

# Test a forward pass
noise = torch.randn(2, noise_dim)
fake_images = generator(noise)
print(f"Generated image shape: {fake_images.shape}")  # torch.Size([2, 3, 64, 64])
decision = discriminator(fake_images)
print(f"Discriminator output shape: {decision.shape}")  # torch.Size([2, 1])
3.3 Full DCGAN Training Code
Example
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

NOISE_DIM = 100
LEARNING_RATE = 0.0002
BETA1 = 0.5  # Adam beta1

# Create the networks
generator = DCGenerator(noise_dim=NOISE_DIM).to(device)
discriminator = DCDiscriminator().to(device)

# Optimizers
g_optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))
criterion = nn.BCELoss()

# ── Training loop ─────────────────────────────────────
fixed_noise = torch.randn(64, NOISE_DIM, device=device)  # for visualization

def train_dcgan(generator, discriminator, g_optimizer, d_optimizer, criterion,
                num_epochs, device, fixed_noise):
    G_losses = []
    D_losses = []

    for epoch in range(num_epochs):
        for batch_idx in range(100):  # assume 100 batches per epoch
            # Train the discriminator
            discriminator.zero_grad()
            # Real images (assumed to come from a DataLoader)
            # real_images = ...
            # Simulated here with random noise
            real_images = torch.randn(32, 3, 64, 64).to(device)
            batch_size = real_images.size(0)
            labels_real = torch.ones(batch_size, 1).to(device)
            labels_fake = torch.zeros(batch_size, 1).to(device)

            # Loss on real images
            output = discriminator(real_images)
            d_loss_real = criterion(output, labels_real)

            # Loss on generated images
            noise = torch.randn(batch_size, NOISE_DIM).to(device)
            fake_images = generator(noise)
            output = discriminator(fake_images.detach())
            d_loss_fake = criterion(output, labels_fake)

            # Total loss
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            d_optimizer.step()

            # Train the generator
            generator.zero_grad()
            noise = torch.randn(batch_size, NOISE_DIM).to(device)
            fake_images = generator(noise)
            output = discriminator(fake_images)
            g_loss = criterion(output, labels_real)  # want fakes classified as real
            g_loss.backward()
            g_optimizer.step()

            # Record losses
            if batch_idx % 50 == 0:
                G_losses.append(g_loss.item())
                D_losses.append(d_loss.item())
                print(f"[{epoch}/{num_epochs}][{batch_idx}/100] "
                      f"D_loss: {d_loss.item():.4f} | G_loss: {g_loss.item():.4f}")

    return G_losses, D_losses

# Start training
# G_losses, D_losses = train_dcgan(generator, discriminator, g_optimizer,
#                                  d_optimizer, criterion, 5, device, fixed_noise)
print("DCGAN architecture defined; ready to train!")
4. Tips for Training GANs
4.1 Common Problems and Remedies
| Problem | Cause | Remedies |
|---|---|---|
| Mode collapse | The generator keeps producing only a few kinds of samples | Use WGAN, minibatch discrimination, or label smoothing |
| Discriminator too strong | Generator gradients vanish and it stops learning | Train the generator for more steps, lower the discriminator's learning rate, use LeakyReLU |
| Unstable training | The GAN objective is non-convex and hard to converge | Spectral normalization, gradient penalty, learning-rate warmup |
| Poor sample quality | Insufficient network capacity or training | Deeper networks, more training data, longer training |
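One-sided label smoothing, mentioned in the table, simply replaces the real-label target 1.0 with a softer value such as 0.9. The toy comparison below shows the effect: with smoothed targets, even an overconfident discriminator keeps receiving a non-vanishing loss (the specific prediction values are made up for illustration):

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()
# Overconfident discriminator outputs on real samples
real_pred = torch.tensor([[0.99], [0.98]])

hard = torch.ones_like(real_pred)         # plain targets of 1.0
smooth = torch.full_like(real_pred, 0.9)  # one-sided smoothing: targets of 0.9

loss_hard = criterion(real_pred, hard)
loss_smooth = criterion(real_pred, smooth)
# Smoothing penalizes overconfidence, keeping useful gradients flowing
print(loss_hard.item() < loss_smooth.item())  # True
```

Note the smoothing is one-sided: fake targets stay at 0.0; only the real targets are softened.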
4.2 Improved Loss Functions
The original GAN objective effectively minimizes a JS divergence, which suffers from vanishing gradients. WGAN replaces it with the Wasserstein distance, which is more stable:
Example
import torch

def wgan_d_loss(real_pred, fake_pred):
    """Critic loss: score real samples high and generated samples low"""
    return -(torch.mean(real_pred) - torch.mean(fake_pred))

def wgan_g_loss(fake_pred):
    """Generator loss: push the scores of generated samples up"""
    return -torch.mean(fake_pred)

# Gradient penalty - WGAN-GP
def gradient_penalty(discriminator, real_images, fake_images, device):
    """WGAN-GP gradient penalty term"""
    batch_size = real_images.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1).to(device)
    # Interpolate between real and generated images
    interpolated = (alpha * real_images + (1 - alpha) * fake_images).requires_grad_(True)
    # Critic output on the interpolated images
    pred = discriminator(interpolated)
    # Gradient of the output w.r.t. the interpolated input
    gradients = torch.autograd.grad(
        outputs=pred,
        inputs=interpolated,
        grad_outputs=torch.ones_like(pred),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    # Penalize deviation of the gradient norm from 1
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    penalty = ((gradient_norm - 1) ** 2).mean()
    return penalty
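Putting the pieces together, a single WGAN-GP critic update might look like the following sketch. It uses a toy linear critic on vector data rather than images, so the shapes are illustrative; `lambda_gp = 10` is the penalty weight suggested in the WGAN-GP paper:

```python
import torch
import torch.nn as nn

# A tiny linear "critic" (no Sigmoid: WGAN critics output raw scores)
critic = nn.Linear(4, 1)
opt = torch.optim.Adam(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))
lambda_gp = 10.0  # standard penalty weight from the WGAN-GP paper

real = torch.randn(16, 4)
fake = torch.randn(16, 4)  # stands in for (detached) generator output

# Interpolate between real and fake, then penalize the gradient norm's
# deviation from 1 along those interpolated points
alpha = torch.rand(16, 1)
interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
pred = critic(interp)
grads = torch.autograd.grad(pred, interp, torch.ones_like(pred),
                            create_graph=True)[0]
gp = ((grads.view(16, -1).norm(2, dim=1) - 1) ** 2).mean()

# Critic loss: maximize E[D(real)] - E[D(fake)], plus the penalty
d_loss = -(critic(real).mean() - critic(fake).mean()) + lambda_gp * gp
opt.zero_grad(); d_loss.backward(); opt.step()
print(float(gp) >= 0.0)  # True: the penalty is always non-negative
```

The generator step is unchanged from `wgan_g_loss` above; the penalty only enters the critic's objective.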
4.3 Spectral Normalization
Spectral normalization stabilizes GAN training by constraining the discriminator's Lipschitz constant:
Example
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Discriminator with spectral normalization (sized for 32x32 inputs)
class SNDiscriminator(nn.Module):
    def __init__(self, channels=3, features_d=64):
        super().__init__()
        self.net = nn.Sequential(
            spectral_norm(nn.Conv2d(channels, features_d, 4, 2, 1)),            # 32 -> 16
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(features_d, features_d * 2, 4, 2, 1)),      # 16 -> 8
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1)),  # 8 -> 4
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(features_d * 4, 1, 4, 1, 0)),               # 4 -> 1
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x).view(x.size(0), -1)
5. Conditional GAN (cGAN)
A conditional GAN lets you specify a class label for the data to be generated, enabling controlled, conditional generation.
5.1 cGAN Architecture
Example
import torch
import torch.nn as nn

class ConditionalGenerator(nn.Module):
    """Conditional generator: takes both noise and a class label"""
    def __init__(self, noise_dim, num_classes, embed_dim, img_channels, features_g=64):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, embed_dim)
        # Concatenate the noise with the label embedding
        self.net = nn.Sequential(
            nn.Linear(noise_dim + embed_dim, features_g * 8 * 4 * 4),
            nn.BatchNorm1d(features_g * 8 * 4 * 4),
            nn.ReLU(),
            # followed by transposed convolutions (as in DCGAN)
            # ...
        )

    def forward(self, noise, labels):
        # Embed the class label
        label_embedding = self.label_emb(labels)
        # Concatenate noise and label embedding
        x = torch.cat([noise, label_embedding], dim=1)
        return self.net(x)

class ConditionalDiscriminator(nn.Module):
    """Conditional discriminator: takes both an image and a class label"""
    def __init__(self, img_channels, num_classes, embed_dim, features_d=64):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, embed_dim)
        # Concatenate the image with the label embedding
        self.net = nn.Sequential(
            nn.Conv2d(img_channels + embed_dim, features_d, 4, 2, 1),
            nn.LeakyReLU(0.2),
            # ...
        )

    def forward(self, img, labels):
        # Broadcast the label embedding to the image's spatial size
        label_embedding = self.label_emb(labels)
        # Reshape so it can be concatenated along the channel dimension
        label_embedding = label_embedding.unsqueeze(2).unsqueeze(3)
        label_embedding = label_embedding.expand(-1, -1, img.size(2), img.size(3))
        # Concatenate image and label map
        x = torch.cat([img, label_embedding], dim=1)
        return self.net(x)
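The conditioning mechanics of these two classes can be checked in isolation. The dimensions below (10 classes, an 8-dim embedding, 100-dim noise, 64×64 RGB images) are example values:

```python
import torch
import torch.nn as nn

num_classes, embed_dim, noise_dim = 10, 8, 100
label_emb = nn.Embedding(num_classes, embed_dim)

noise = torch.randn(4, noise_dim)
labels = torch.randint(0, num_classes, (4,))

# Generator side: concatenate noise with the label embedding
g_in = torch.cat([noise, label_emb(labels)], dim=1)
print(g_in.shape)  # torch.Size([4, 108])

# Discriminator side: broadcast the embedding over the spatial dims,
# then stack it onto the image as extra channels
img = torch.randn(4, 3, 64, 64)
emb_map = label_emb(labels)[:, :, None, None].expand(-1, -1, 64, 64)
d_in = torch.cat([img, emb_map], dim=1)
print(d_in.shape)  # torch.Size([4, 11, 64, 64])
```

This is why the discriminator's first convolution takes `img_channels + embed_dim` input channels.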
6. Evaluation Metrics for GANs
6.1 Common Metrics
| Metric | Description | Pros | Cons |
|---|---|---|---|
| Inception Score (IS) | Uses Inception v3 to assess the quality and diversity of generated images | Easy to compute; correlates somewhat with human judgment | Does not measure overfitting; cannot detect mode collapse |
| Fréchet Inception Distance (FID) | Distance between real and generated images in a feature space | More sensitive to noise and mode collapse | Needs many samples; slower to compute |
| Human evaluation | People judge the quality of generated images | Best reflects human perception | Subjective and time-consuming |
6.2 Implementing FID
Example
import numpy as np
from scipy import linalg

def calculate_fid(real_activations, fake_activations):
    """
    Compute the Fréchet Inception Distance
    real_activations: feature vectors of real images (N, dim)
    fake_activations: feature vectors of generated images (N, dim)
    """
    # Means and covariances of the two feature sets
    mu1, sigma1 = real_activations.mean(axis=0), np.cov(real_activations, rowvar=False)
    mu2, sigma2 = fake_activations.mean(axis=0), np.cov(fake_activations, rowvar=False)

    # FID
    diff = mu1 - mu2
    # Matrix square root of the product of the covariances
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    # Discard any imaginary component caused by numerical error
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)
    return fid

# Simplified example using random data
np.random.seed(42)
real_acts = np.random.randn(1000, 2048)  # Inception v3 feature dimension
fake_acts = np.random.randn(1000, 2048)

fid_score = calculate_fid(real_acts, fake_acts)
print(f"FID Score: {fid_score:.2f}")
# Lower FID is better; below ~50 usually indicates reasonable quality
7. Common GAN Variants
GANs have spawned many variants suited to different applications:
| Model | Full Name | Highlights | Typical Use |
|---|---|---|---|
| DCGAN | Deep Convolutional GAN | Convolutional networks; high-quality image generation | Image generation |
| WGAN | Wasserstein GAN | Wasserstein distance; more stable training | Stable training |
| WGAN-GP | WGAN with Gradient Penalty | Gradient penalty instead of weight clipping | Stable training |
| CGAN | Conditional GAN | Conditioning information for controllable generation | Conditional generation |
| CycleGAN | Cycle-Consistent GAN | Unsupervised image-to-image translation | Style transfer |
| StyleGAN | Style-Based GAN | Style control; high-quality face generation | Face generation |
| BigGAN | Big GAN | Large-scale, high-quality image generation | High-resolution images |
| ProGAN | Progressive Growing GAN | Progressively grows the resolution | High-resolution generation |
