提示信息

PyTorch之生成对抗网络

沦陷的痛 已关注

生成对抗网络(Generative Adversarial Networks,GANs)是由Ian Goodfellow等人在2014年提出的一种深度学习模型。它由两个相互竞争的神经网络组成:生成器(Generator)和判别器(Discriminator)。生成器负责生成看似真实的假数据,而判别器则需要区分这些假数据与真实数据。通过这种对抗的方式,GANs可以生成与真实数据分布非常接近的样本。

下面简要介绍如何使用PyTorch实现一个简单的GAN模型:

1. 环境准备

确保已安装PyTorch,可以通过pip安装:

  1. pip install torch torchvision

2. 数据准备

通常情况下,GAN会应用到图片生成任务中,因此我们可以使用MNIST数据集作为例子。

  1. from torchvision import datasets, transforms
  2. from torch.utils.data import DataLoader
  3. transform = transforms.Compose([
  4. transforms.ToTensor(),
  5. transforms.Normalize((0.5,), (0.5,)),
  6. ])
  7. mnist_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
  8. dataloader = DataLoader(mnist_data, batch_size=64, shuffle=True)

3. 定义生成器和判别器

生成器网络将随机噪声向量转换为对抗样本。判别器网络则是一个二分类器,用于区分真实样本和生成样本。

  1. import torch
  2. import torch.nn as nn
  3. class Generator(nn.Module):
  4. def __init__(self, input_size, hidden_size, output_size):
  5. super(Generator, self).__init__()
  6. self.main = nn.Sequential(
  7. nn.Linear(input_size, hidden_size),
  8. nn.LeakyReLU(0.2),
  9. nn.Linear(hidden_size, hidden_size * 2),
  10. nn.LeakyReLU(0.2),
  11. nn.Linear(hidden_size * 2, hidden_size * 4),
  12. nn.LeakyReLU(0.2),
  13. nn.Linear(hidden_size * 4, output_size),
  14. nn.Tanh()
  15. )
  16. def forward(self, x):
  17. return self.main(x)
  18. class Discriminator(nn.Module):
  19. def __init__(self, input_size, hidden_size):
  20. super(Discriminator, self).__init__()
  21. self.main = nn.Sequential(
  22. nn.Linear(input_size, hidden_size * 4),
  23. nn.LeakyReLU(0.2),
  24. nn.Linear(hidden_size * 4, hidden_size * 2),
  25. nn.LeakyReLU(0.2),
  26. nn.Linear(hidden_size * 2, hidden_size),
  27. nn.LeakyReLU(0.2),
  28. nn.Linear(hidden_size, 1),
  29. nn.Sigmoid()
  30. )
  31. def forward(self, x):
  32. return self.main(x)

4. 定义损失函数和优化器

使用二元交叉熵损失和Adam优化器。

  1. criterion = nn.BCELoss()
  2. lr = 0.0002 # 学习率
  3. latent_size, hidden_size, image_size = 100, 256, 784
  4. generator = Generator(latent_size, hidden_size, image_size)
  5. discriminator = Discriminator(image_size, hidden_size)
  6. g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
  7. d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)

5. 训练GAN

在训练过程中,判别器根据生成样本和真实样本更新,然后生成器根据判别器的反馈改进生成样本。

  1. num_epochs = 100
  2. for epoch in range(num_epochs):
  3. for real_images, _ in dataloader:
  4. batch_size = real_images.size(0)
  5. real_images = real_images.view(batch_size, -1)
  6. # 创建标签
  7. real_labels = torch.ones(batch_size, 1)
  8. fake_labels = torch.zeros(batch_size, 1)
  9. # ---------------------
  10. # 训练判别器
  11. # ---------------------
  12. outputs = discriminator(real_images)
  13. d_loss_real = criterion(outputs, real_labels)
  14. real_score = outputs
  15. z = torch.randn(batch_size, latent_size)
  16. fake_images = generator(z)
  17. outputs = discriminator(fake_images)
  18. d_loss_fake = criterion(outputs, fake_labels)
  19. fake_score = outputs
  20. d_loss = d_loss_real + d_loss_fake
  21. d_optimizer.zero_grad()
  22. d_loss.backward()
  23. d_optimizer.step()
  24. # -----------------
  25. # 训练生成器
  26. # -----------------
  27. z = torch.randn(batch_size, latent_size)
  28. fake_images = generator(z)
  29. outputs = discriminator(fake_images)
  30. g_loss = criterion(outputs, real_labels)
  31. g_optimizer.zero_grad()
  32. g_loss.backward()
  33. g_optimizer.step()
  34. print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, \
  35. g_loss: {g_loss.item():.4f}, D(x): {real_score.mean().item():.2f}, D(G(z)): {fake_score.mean().item():.2f}')

上述代码实现了一个简单的GAN模型,可以生成看似真实的手写数字。要提高模型效果,您可以考虑使用更复杂的网络结构,以及更精细的训练技巧,比如渐进式增长、标签平滑或谱归一化等。

    遇到难题? "AI大模型GPT4.0、GPT" 是你的私人解答专家! 点击按钮去提问......
沦陷的痛 关注 已关注

最近一次登录:2024-11-09 17:28:50   

暂时还没有签名,请关注我或评论我的文章
×
免费图表工具,画流程图、架构图