深度学习中的 GAN 生成对抗网络原理推导与代码实现(Python)
生成对抗网络(Generative Adversarial Networks, GAN)的核心思想是通过两个相互竞争的神经网络模型来实现数据生成。GAN 由一个生成器(Generator)和一个判别器(Discriminator)组成。生成器负责生成类似真实数据的假数据,判别器则负责鉴别数据的真假。通过这个对抗过程,生成器不断改进生成的数据质量。
GAN 的基本原理
生成器(G):
- 输入:随机噪声 ( z )(通常来自标准正态分布)。
- 输出:生成的数据 ( G(z) ),尝试尽可能地近似真实数据分布。
判别器(D):
- 输入:真实数据样本或生成数据样本。
- 输出:概率标量,表示输入数据是“真实”的概率。
对抗过程:
- 判别器 ( D ) 尽量将真实样本标记为 1,将生成样本标记为 0。
- 生成器 ( G ) 尽量让 ( D ) 将生成的样本标记为 1。
损失函数:
- 生成器的目标是最大化判别器对生成数据的错误判断,因此优化的是:( \max_{G} \, \log(D(G(z))) )。
- 判别器的目标是最大化对样本的正确分类,因此优化的是:( \max_{D} \, [\log(D(x)) + \log(1 - D(G(z)))] )。
代码实现
下面是一个简单的 GAN 实现示例,采用 PyTorch:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义生成器
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.main = nn.Sequential(
nn.Linear(100, 256),
nn.ReLU(True),
nn.Linear(256, 512),
nn.ReLU(True),
nn.Linear(512, 1024),
nn.ReLU(True),
nn.Linear(1024, 28*28),
nn.Tanh()
)
def forward(self, input):
return self.main(input).view(-1, 1, 28, 28)
# 定义判别器
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
nn.Linear(28*28, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, input):
return self.main(input.view(-1, 28*28))
# 定义超参数
batch_size = 64
lr = 0.0002
num_epochs = 200
# 数据加载
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# 实例化模型
G = Generator()
D = Discriminator()
# 定义优化器
criterion = nn.BCELoss()
optimizerD = optim.Adam(D.parameters(), lr=lr)
optimizerG = optim.Adam(G.parameters(), lr=lr)
# 训练 GAN
for epoch in range(num_epochs):
for i, (data, _) in enumerate(dataloader):
# 更新判别器
real_data = data
batch_size = real_data.size(0)
labels_real = torch.ones(batch_size, 1)
labels_fake = torch.zeros(batch_size, 1)
# 判别真实数据
output = D(real_data)
lossD_real = criterion(output, labels_real)
# 判别生成数据
noise = torch.randn(batch_size, 100)
fake_data = G(noise)
output = D(fake_data.detach())
lossD_fake = criterion(output, labels_fake)
# 总判别器损失
lossD = lossD_real + lossD_fake
optimizerD.zero_grad()
lossD.backward()
optimizerD.step()
# 更新生成器
output = D(fake_data)
lossG = criterion(output, labels_real)
optimizerG.zero_grad()
lossG.backward()
optimizerG.step()
print(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {lossD.item()}, Loss G: {lossG.item()}")
print("训练完成!")
解释
数据加载与预处理:
- 使用 MNIST 数据集作为真实数据。
- 数据标准化到 ([-1, 1]) 范围以匹配生成器输出的范围。
模型定义:
Generator
: 将随机向量转换为数据样本。Discriminator
: 判断数据是现实的还是生成的。
优化与损失:
- 使用
binary cross entropy loss
作为损失函数。 - 使用
Adam
优化器用来更新模型权重。
- 使用
训练循环:
- 交替训练判别器和生成器,确保生成器的输出逐渐逼近真实数据分布。
这种简单的 GAN 结构适合生成小规模的数据集(例如 MNIST 等),更复杂的数据生成通常需要引入更高级的技术,如卷积层(DCGAN),或者使用改进的训练策略。