提示信息

深度学习中的 GAN 生成对抗网络原理推导与代码实现(Python)

球迷女孩 已关注

生成对抗网络(Generative Adversarial Networks, GAN)的核心思想是通过两个相互竞争的神经网络模型来实现数据生成。GAN 由一个生成器(Generator)和一个判别器(Discriminator)组成。生成器负责生成类似真实数据的假数据,判别器则负责鉴别数据的真假。通过这个对抗过程,生成器不断改进生成的数据质量。

GAN 的基本原理

  1. 生成器(G):

    • 输入:随机噪声 ( z )(通常来自标准正态分布)。
    • 输出:生成的数据 ( G(z) ),尝试尽可能地近似真实数据分布。
  2. 判别器(D):

    • 输入:真实数据样本或生成数据样本。
    • 输出:概率标量,表示输入数据是“真实”的概率。
  3. 对抗过程:

    • 判别器 ( D ) 尽量将真实样本标记为 1,将生成样本标记为 0。
    • 生成器 ( G ) 尽量让 ( D ) 将生成的样本标记为 1。
  4. 损失函数:

    • 生成器的目标是最大化判别器对生成数据的错误判断,因此优化的是:( \max_{G} \, \log(D(G(z))) )。
    • 判别器的目标是最大化对样本的正确分类,因此优化的是:( \max_{D} \, [\log(D(x)) + \log(1 - D(G(z)))] )。

代码实现

下面是一个简单的 GAN 实现示例,采用 PyTorch:

  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. from torchvision import datasets, transforms
  5. from torch.utils.data import DataLoader
  6. # 定义生成器
  7. class Generator(nn.Module):
  8. def __init__(self):
  9. super(Generator, self).__init__()
  10. self.main = nn.Sequential(
  11. nn.Linear(100, 256),
  12. nn.ReLU(True),
  13. nn.Linear(256, 512),
  14. nn.ReLU(True),
  15. nn.Linear(512, 1024),
  16. nn.ReLU(True),
  17. nn.Linear(1024, 28*28),
  18. nn.Tanh()
  19. )
  20. def forward(self, input):
  21. return self.main(input).view(-1, 1, 28, 28)
  22. # 定义判别器
  23. class Discriminator(nn.Module):
  24. def __init__(self):
  25. super(Discriminator, self).__init__()
  26. self.main = nn.Sequential(
  27. nn.Linear(28*28, 1024),
  28. nn.LeakyReLU(0.2, inplace=True),
  29. nn.Linear(1024, 512),
  30. nn.LeakyReLU(0.2, inplace=True),
  31. nn.Linear(512, 256),
  32. nn.LeakyReLU(0.2, inplace=True),
  33. nn.Linear(256, 1),
  34. nn.Sigmoid()
  35. )
  36. def forward(self, input):
  37. return self.main(input.view(-1, 28*28))
  38. # 定义超参数
  39. batch_size = 64
  40. lr = 0.0002
  41. num_epochs = 200
  42. # 数据加载
  43. transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
  44. train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
  45. dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
  46. # 实例化模型
  47. G = Generator()
  48. D = Discriminator()
  49. # 定义优化器
  50. criterion = nn.BCELoss()
  51. optimizerD = optim.Adam(D.parameters(), lr=lr)
  52. optimizerG = optim.Adam(G.parameters(), lr=lr)
  53. # 训练 GAN
  54. for epoch in range(num_epochs):
  55. for i, (data, _) in enumerate(dataloader):
  56. # 更新判别器
  57. real_data = data
  58. batch_size = real_data.size(0)
  59. labels_real = torch.ones(batch_size, 1)
  60. labels_fake = torch.zeros(batch_size, 1)
  61. # 判别真实数据
  62. output = D(real_data)
  63. lossD_real = criterion(output, labels_real)
  64. # 判别生成数据
  65. noise = torch.randn(batch_size, 100)
  66. fake_data = G(noise)
  67. output = D(fake_data.detach())
  68. lossD_fake = criterion(output, labels_fake)
  69. # 总判别器损失
  70. lossD = lossD_real + lossD_fake
  71. optimizerD.zero_grad()
  72. lossD.backward()
  73. optimizerD.step()
  74. # 更新生成器
  75. output = D(fake_data)
  76. lossG = criterion(output, labels_real)
  77. optimizerG.zero_grad()
  78. lossG.backward()
  79. optimizerG.step()
  80. print(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {lossD.item()}, Loss G: {lossG.item()}")
  81. print("训练完成!")

解释

  1. 数据加载与预处理:

    • 使用 MNIST 数据集作为真实数据。
    • 数据标准化到 ([-1, 1]) 范围以匹配生成器输出的范围。
  2. 模型定义:

    • Generator: 将随机向量转换为数据样本。
    • Discriminator: 判断数据是现实的还是生成的。
  3. 优化与损失:

    • 使用 binary cross entropy loss 作为损失函数。
    • 使用 Adam 优化器用来更新模型权重。
  4. 训练循环:

    • 交替训练判别器和生成器,确保生成器的输出逐渐逼近真实数据分布。

这种简单的 GAN 结构适合生成小规模的数据集(例如 MNIST 等),更复杂的数据生成通常需要引入更高级的技术,如卷积层(DCGAN),或者使用改进的训练策略。

    遇到难题? "AI大模型GPT4.0、GPT" 是你的私人解答专家! 点击按钮去提问......
球迷女孩 关注 已关注

最近一次登录:2024-11-20 22:04:49   

暂时还没有签名,请关注我或评论我的文章
×
免费图表工具,画流程图、架构图