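Generative Adversarial Networks in PyTorch
Generative adversarial networks (GANs) are a class of deep learning models introduced by Ian Goodfellow et al. in 2014. A GAN consists of two competing neural networks: a generator and a discriminator. The generator produces synthetic data meant to look real, while the discriminator tries to tell that synthetic data apart from real data. Through this adversarial training, a GAN can learn to generate samples whose distribution closely approximates that of the real data.
The following sections briefly show how to implement a simple GAN in PyTorch: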
1. Environment setup
Make sure PyTorch is installed. It can be installed with pip:
pip install torch torchvision
2. Data preparation
GANs are most often applied to image generation tasks, so this example uses the MNIST dataset.
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Convert images to tensors in [0, 1], then rescale them to [-1, 1]
# so that they match the range of the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
mnist_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist_data, batch_size=64, shuffle=True)
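3. Define the generator and discriminator
The generator network transforms a random noise vector into a synthetic sample (here, a flattened image). The discriminator network is a binary classifier that distinguishes real samples from generated ones.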
import torch
import torch.nn as nn


class Generator(nn.Module):
    # Maps a noise vector to a flattened image with values in [-1, 1].
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size, hidden_size * 2),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size * 2, hidden_size * 4),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size * 4, output_size),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)


class Discriminator(nn.Module):
    # Maps a flattened image to the probability that it is real.
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(input_size, hidden_size * 4),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size * 4, hidden_size * 2),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size * 2, hidden_size),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x)
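One way to confirm that the two networks fit together is to push a batch of random noise through both and inspect the shapes. The sketch below rebuilds the same layer stacks with `nn.Sequential` so that it runs on its own. The helper names `make_generator` and `make_discriminator`, the batch size of 8, and the sizes 100, 256 and 784 are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn


def make_generator(latent_size, hidden_size, image_size):
    # Same layer stack as the Generator class: noise -> flattened image in [-1, 1].
    return nn.Sequential(
        nn.Linear(latent_size, hidden_size), nn.LeakyReLU(0.2),
        nn.Linear(hidden_size, hidden_size * 2), nn.LeakyReLU(0.2),
        nn.Linear(hidden_size * 2, hidden_size * 4), nn.LeakyReLU(0.2),
        nn.Linear(hidden_size * 4, image_size), nn.Tanh(),
    )


def make_discriminator(image_size, hidden_size):
    # Same layer stack as the Discriminator class: flattened image -> probability.
    return nn.Sequential(
        nn.Linear(image_size, hidden_size * 4), nn.LeakyReLU(0.2),
        nn.Linear(hidden_size * 4, hidden_size * 2), nn.LeakyReLU(0.2),
        nn.Linear(hidden_size * 2, hidden_size), nn.LeakyReLU(0.2),
        nn.Linear(hidden_size, 1), nn.Sigmoid(),
    )


latent_size, hidden_size, image_size = 100, 256, 784  # assumed sizes
G = make_generator(latent_size, hidden_size, image_size)
D = make_discriminator(image_size, hidden_size)

z = torch.randn(8, latent_size)  # a batch of 8 noise vectors
with torch.no_grad():
    fake = G(z)      # one flattened image per noise vector
    score = D(fake)  # one probability per image

print(fake.shape, score.shape)
```

The generator output has one row per noise vector and one column per pixel, and the discriminator reduces each row to a single probability. This is why the labels used during training have shape `(batch_size, 1)`.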
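4. Define the loss function and optimizers
Use binary cross-entropy loss and the Adam optimizer, with a separate optimizer for each network.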
criterion = nn.BCELoss()
lr = 0.0002  # learning rate
latent_size, hidden_size, image_size = 100, 256, 784
generator = Generator(latent_size, hidden_size, image_size)
discriminator = Discriminator(image_size, hidden_size)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
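To make the role of `nn.BCELoss` concrete, the following small worked example compares its result with the closed-form expression. The probabilities 0.9 and 0.2 are made-up discriminator outputs, used here purely for illustration.

```python
import math

import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Hypothetical discriminator outputs for a batch of two samples.
probs = torch.tensor([[0.9], [0.2]])

# With target 1 ("real"), the per-sample loss is -log(p).
loss_real = criterion(probs, torch.ones(2, 1))
# With target 0 ("fake"), the per-sample loss is -log(1 - p).
loss_fake = criterion(probs, torch.zeros(2, 1))

# BCELoss averages over the batch by default.
expected_real = -(math.log(0.9) + math.log(0.2)) / 2
expected_fake = -(math.log(0.1) + math.log(0.8)) / 2

print(loss_real.item(), expected_real)
print(loss_fake.item(), expected_fake)
```

The same criterion therefore serves both networks: the discriminator is trained with target 1 on real samples and target 0 on generated ones, while the generator is trained with target 1 on its own samples.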
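5. Train the GAN
In each training iteration, the discriminator is first updated on a batch of real samples and a batch of generated samples. The generator is then updated using the discriminator's feedback on newly generated samples.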
num_epochs = 100
for epoch in range(num_epochs):
    for real_images, _ in dataloader:
        batch_size = real_images.size(0)
        # Flatten each 1x28x28 image into a 784-dimensional vector
        real_images = real_images.view(batch_size, -1)

        # Create labels: 1 for real samples, 0 for generated samples
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # ---------------------
        # Train the discriminator
        # ---------------------
        outputs = discriminator(real_images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        z = torch.randn(batch_size, latent_size)
        # detach() keeps this step from backpropagating into the generator
        fake_images = generator(z).detach()
        outputs = discriminator(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # -----------------
        # Train the generator
        # -----------------
        z = torch.randn(batch_size, latent_size)
        fake_images = generator(z)
        outputs = discriminator(fake_images)
        # Use the "real" labels so the generator is rewarded for fooling the discriminator
        g_loss = criterion(outputs, real_labels)
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

    # Report the statistics of the last batch of each epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, '
          f'g_loss: {g_loss.item():.4f}, D(x): {real_score.mean().item():.2f}, '
          f'D(G(z)): {fake_score.mean().item():.2f}')
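Once training has finished, the generator can be sampled on its own. The sketch below shows one possible way to do this. The helper `sample_images` is a hypothetical name, and the untrained `stand_in` network is only there so the example runs by itself. In practice the trained `generator` from the loop above would be passed in instead.

```python
import torch
import torch.nn as nn


def sample_images(generator, latent_size, n):
    """Draw n samples and return them as an (n, 1, 28, 28) tensor in [0, 1]."""
    with torch.no_grad():                # no gradients are needed for sampling
        z = torch.randn(n, latent_size)
        flat = generator(z)              # (n, 784), values in [-1, 1] from Tanh
    images = flat.view(n, 1, 28, 28)     # restore the MNIST image shape
    return (images + 1) / 2              # undo Normalize((0.5,), (0.5,))


# Untrained stand-in with the same input and output sizes as the tutorial's generator.
stand_in = nn.Sequential(nn.Linear(100, 784), nn.Tanh())
images = sample_images(stand_in, latent_size=100, n=16)
print(images.shape)
```

The returned tensor is in the layout that `torchvision.utils.save_image` accepts, so a grid of samples can be written to disk for visual inspection.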
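The code above implements a simple GAN that can generate plausible-looking handwritten digits. To improve the results, consider using a more complex network architecture and more refined training techniques, such as progressive growing, label smoothing, or spectral normalization.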
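Two of these techniques, label smoothing and spectral normalization, need only small changes to the discriminator step. The sketch below is one possible way to apply them, not a tuned recipe. The smoothed target value of 0.9, the two-layer network and the random input batch are assumptions for illustration. Progressive growing changes the architecture during training and is not covered here.

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

criterion = nn.BCELoss()
batch_size = 4

# One-sided label smoothing: real targets become 0.9 instead of 1.0,
# while the targets for generated samples stay at 0.
smooth_real_labels = torch.full((batch_size, 1), 0.9)
fake_labels = torch.zeros(batch_size, 1)

# Spectral normalization wraps each linear layer and rescales its weight
# matrix by an estimate of its largest singular value.
sn_discriminator = nn.Sequential(
    spectral_norm(nn.Linear(784, 256)), nn.LeakyReLU(0.2),
    spectral_norm(nn.Linear(256, 1)), nn.Sigmoid(),
)

x = torch.randn(batch_size, 784)   # random stand-in for a batch of real images
outputs = sn_discriminator(x)
d_loss_real = criterion(outputs, smooth_real_labels)
d_loss_real.backward()

print(outputs.shape, d_loss_real.item())
```

With smoothed targets, the loss on real samples can no longer reach zero, which discourages the discriminator from becoming overconfident.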