PyTorch深度学习实战:手算生成对抗网络(GAN)从零到一完整教程及代码分析
先聊几个核心判断:如果你正在研究如何低成本生成高质量的图像、音频或文本数据,那生成对抗网络(GAN)几乎是你绕不开的核心技术。它为什么这么厉害?我们一步步来看。
核心思想
可以这样理解,GAN的核心思想就像一场经典的“猫鼠游戏”。一个角色负责造假,想把假货做得以假乱真;另一个角色则负责打假,想尽办法找出赝品。两个玩家不断对抗、互相促进,最终结果是,造假者(生成器)的水平越来越高,到了足以骗过鉴定者(判别器)的地步,从而产生高度真实的数据。
具体来说,就是让生成器和判别器进行一场“军备竞赛”式的对抗训练。最终目标,是让生成器能产出高质量的“假数据”,其逼真程度甚至能骗过人类的眼睛。这些生成的图像,和原始的真实图像相比,几乎难辨雌雄。更重要的是,使用GAN来生成数据的成本相当低廉,而生成的结果可以直接应用于图像生成、数据增强、风格迁移等各个前沿领域。
网络结构
GAN由两个核心神经网络组成,它们的关系就像一组“冤家”:
生成器(Generator, G):负责“造假”。它的任务是从随机噪声中,生成与真实数据分布高度相似的假数据。
判别器(Discriminator, D):负责“打假”。它的任务是对输入的数据进行二分类——判断它究竟是来自真实世界的原始数据,还是生成器伪造出来的假数据。
两个网络在不断对抗中共同成长。最终,生成器能够学会生成足以乱真的数据,而判别器也练就了一双“火眼金睛”,越来越擅长分辨真伪。
生成器的目标,是生成判别器完全辨不出来的假数据。判别器的目标,则是精准区分真实数据和生成数据。经过多轮交替迭代的“相爱相杀”,生成器能生成和训练集极其相似的假图,而判别器也能准确判断出真实图片和生成图片的细微差别。
手算模拟
光说不练假把式,我们用一段代码来模拟一下这个过程,看看数据是如何在神经网络中流动并发生变化的。
训练生成器
首先,我们看看生成器的结构。它由一系列转置卷积层组成,目标是将一个100维的随机噪声向量,逐步“上采样”成一个64x64的RGB三通道图像。
生成器结构:
Sequential(
(0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU(inplace=True)
(12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(13): Tanh()
)
输入噪声形状: torch.Size([2, 100, 1, 1])
第1层 | ConvTranspose2d | 输出形状: torch.Size([2, 512, 4, 4])
第2层 | BatchNorm2d | 输出形状: torch.Size([2, 512, 4, 4])
第3层 | ReLU | 输出形状: torch.Size([2, 512, 4, 4])
第4层 | ConvTranspose2d | 输出形状: torch.Size([2, 256, 8, 8])
第5层 | BatchNorm2d | 输出形状: torch.Size([2, 256, 8, 8])
第6层 | ReLU | 输出形状: torch.Size([2, 256, 8, 8])
第7层 | ConvTranspose2d | 输出形状: torch.Size([2, 128, 16, 16])
第8层 | BatchNorm2d | 输出形状: torch.Size([2, 128, 16, 16])
第9层 | ReLU | 输出形状: torch.Size([2, 128, 16, 16])
第10层 | ConvTranspose2d | 输出形状: torch.Size([2, 64, 32, 32])
第11层 | BatchNorm2d | 输出形状: torch.Size([2, 64, 32, 32])
第12层 | ReLU | 输出形状: torch.Size([2, 64, 32, 32])
第13层 | ConvTranspose2d | 输出形状: torch.Size([2, 3, 64, 64])
第14层 | Tanh | 输出形状: torch.Size([2, 3, 64, 64])
训练判别器
接下来是判别器的结构。与生成器相反,它由一系列卷积层组成,逐步“下采样”输入图像,最终输出一个介于0到1之间的概率值,表示输入图像是真实图像的可能性。
Sequential(
(0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): LeakyReLU(negative_slope=0.2, inplace=True)
(11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(12): Sigmoid()
)
===== 判别器逐层输出 =====
输入图像形状: torch.Size([2, 3, 64, 64])
第1层 | Conv2d | 输出形状: torch.Size([2, 64, 32, 32])
第2层 | LeakyReLU | 输出形状: torch.Size([2, 64, 32, 32])
第3层 | Conv2d | 输出形状: torch.Size([2, 128, 16, 16])
第4层 | BatchNorm2d | 输出形状: torch.Size([2, 128, 16, 16])
第5层 | LeakyReLU | 输出形状: torch.Size([2, 128, 16, 16])
第6层 | Conv2d | 输出形状: torch.Size([2, 256, 8, 8])
第7层 | BatchNorm2d | 输出形状: torch.Size([2, 256, 8, 8])
第8层 | LeakyReLU | 输出形状: torch.Size([2, 256, 8, 8])
第9层 | Conv2d | 输出形状: torch.Size([2, 512, 4, 4])
第10层 | BatchNorm2d | 输出形状: torch.Size([2, 512, 4, 4])
第11层 | LeakyReLU | 输出形状: torch.Size([2, 512, 4, 4])
第12层 | Conv2d | 输出形状: torch.Size([2, 1, 1, 1])
第13层 | Sigmoid | 输出形状: torch.Size([2, 1, 1, 1])
===== 最终结果 =====
判别器输出概率形状: torch.Size([2, 1, 1, 1])
输出为(batch_size, 1, 1, 1),每个值代表输入图像为真实图像的概率(0~1)
单个样本的训练
训练生成器
生成器的训练目标是“骗过”判别器。也就是说,它要最小化判别器对假图像的识别能力,让判别器认为这些假图像是真实的(标签为1)。
过程是这样的:生成器 netG 根据噪声生成假图片 fake_images(64_64_3)。然后,将这个 fake_images 送入判别器 netD,得到一个输出 output_fake(比如 0.001)。我们期望这个输出应该接近真实标签 torch.ones_like(output_fake, device=device),也就是 1。因此,我们计算 output_fake 和这个期望标签之间的损失,然后反向传播来更新生成器的参数。
训练判别器
判别器的训练则是一道“双向选择题”。首先,它要知道什么是真的。将真实照片(real_images)送入判别器,得到输出 output_real(比如 0.4126)。我们希望这个输出能接近 1,所以计算输出和真实标签(1)之间的损失 (lossD_real) 为 0.8846。
然后,它需要知道什么是假的。将生成器生成的假图片 fake_images 送入判别器,得到输出 output_fake(比如 0.001)。此时,我们期望它的标签是 0,所以计算这个输出和假标签(0)之间的损失 (lossD_fake) 为 0.6961。
最后,判别器的总损失等于这两个损失之和:
lossD = lossD_real + lossD_fake
下面是完整的单样本训练代码,它清晰地展示了这个交替训练的过程。
公式理解
目标函数
为了从数学上理解这场博弈,我们可以用公式来表述。GAN的终极目标,是求解一个“最小最大”博弈(minimax game)。下面这个公式,就是GAN的灵魂所在:
这个公式的含义可以拆解成:
判别器 D 的目标(最大化)
对于判别器 D 来说,它希望最大化这个目标函数,也就是希望它能完美区分真假。具体来看,它想让第一项(对真实数据输出为1的概率)尽可能大,同时让第二项(对假数据输出为1的概率)尽可能小。
而生成器 G 的目标则完全相反,它希望最小化这个目标函数,也就是让它生成的假数据在判别器面前看起来更像是真的。










