生成对抗网络GAN
生成对抗网络 GAN
生成对抗网络 GAN 是一种深度学习模型,它源于 2014 年发表的论文:《Generative Adversarial Nets》,论文地址:https://arxiv.org/pdf/1406.2661.pdf。
GAN 的用途非常广泛,比如:有大量的卡通头像,想通过学习自动生成卡通图片,此问题只提供正例,可视为无监督学习问题。不可能通过人工判断大量数据。如何生成图片?如何评价生成的图片好坏?GAN 为此类问题提供了解决方法。
GAN 同时训练两个模型:生成模型G(Generative Model)和判别模型D(Discriminative Model),生成模型G的目标是学习数据的分布,判别模型D的目标是区别真实数据和模型G生成的数据。以生成卡通图片为例,生成网络 G 的目标是生成尽量真实的图片去欺骗判别网络 D。而 D 的目标就是尽量把 G 生成的图片和真实的图片分别开来。G 和 D 构成了一个动态的“博弈过程”,通过迭代双方能力都不断提高。
对抗网络近年来发展迅速。下图是近几年 ICASSP 会议上所有提交的论文中包含关键词 “generative”、“adversarial” 和 “reinforcement” 的论文数量统计。
用途
- 生成数据 GAN 常用于实现复杂分布上的无监督学习和半监督学习,学习数据的分布,模拟现有数据生成同类型的图片、文本、旋律等等。
- 数据增强 GAN 也用于扩展现有的数据集,即数据增强。使用它训练好的生成网络,可以在数据不足时用于补充数据。
- 生成特定数据 GAN 掌握了数据生成能力后,可通过加入限制,使模型生成特定类型的数据。比如改变图片风格,隐去敏感信息,实现诸如数据加密的功能。
- 使用判断模型 训练好的判别模型可以用于判断数据是否属于该类别,判断数据的真实性,以及判断异常数据。
原理
生成模型和判别模型
机器学习模型大体分为两类,生成模型(Generative Model)和判别模型(Discriminative Model)。生成模型学习得到联合概率分布 P(x,y),即特征 x 和标记 y 共同出现的概率,然后求条件概率分布。能够学习到数据生成的机制;判别模型学习得到条件概率分布 P(y|x),即在特征 x 出现的情况下标记 y 出现的概率。
具体算法
GAN 使用下式评估模型效果:
其中 Pdata 是真实数据的分布,式中左半部分将真实数据 x 代入判别模型 D(x),D 的输出范围是从 0-1,0 为假数据,1 为真数据;由于 x 是真实数据,D 模型希望 D(x)=1;右半部分将随机噪声 z 代入生成模型 G 产生模拟数据 G(z),并使用判别模型 D 判别它是否为真实数据,G 模型希望 D(G(z))=1,1-D(G(z))=0;相反,D 模型希望 D(G(z))=0,1-D(G(z))=1。也就是说,G 希望上式结果越小越好,而 D 希望上式结果越大越好。最终函数 V 既非最大,也非最小,找到双方的利益平衡点——生成数据完全拟合真实数据时达到纳什平衡。
论文中有推导过程,但有些跳步,从这里可以看到详细的推导过程:https://blog.csdn.net/susanzhang1231/article/details/76906340
其具体算法如下:
其中内部的 for 循环用于优化判别模型,先用随机噪声 z 生成 m 个数据,同时从真实数据中取 m 个数据,然后代入判别模型并根据判别结果优化模型参数;外部的 for 循环用于优化生成模型,可以看到生成模型只与公式中右侧计算相关。训练 k 次判别模型,训练 1 次生成模型,二者交替进行。
图中展示了两个模型的优化过程,其中黑色代表真实数据,绿色表示生成数据,蓝色表示判别结果;在图 (a) 中,生成模型没能很好地模拟真实数据分布,差别模型也效果不佳;图 (b) 优化了判别模型;图 (c) 随着生成模型的优化,生成数据逐渐接近真实数据;图 (d) 是最终效果,生成模型完美拟合真实数据,两种数据分布一致,判别模型将无法区分真实数据和生成数据 D(x)=1/2。
代码
推荐例程:https://github.com/RedstoneWill/MachineLearningInAction/blob/master/GAN/GAN_1.ipynb 整个例程不到 100 行,使用 GAN 方法拟合曲线。生成模型 G 和判别模型 D 都使用深度学习网络,且互过逆过程。其核心代码摘录如下:
1 | G = nn.Sequential( # 生成模型 |
整体迭代了 10000 次,每次迭代时取真实数据,并将随机数代入生成模型 G 生成模拟数据,将数据代入判别模型 D,然后根据损失函数调参。下图摘录了曲线拟合不同阶段的结果。
在图像处理领域使用生成对抗网络可以让模型学习生成特定类型的图片。推荐生成卡通人物图片的例程:
https://github.com/chenyuntc/pytorch-book/tree/v1.0/chapter7-GAN
生成动漫头像
(《深度学习框架 PyTorch:入门与实践》第七章的配套代码)
README 中有对应的资源下载地址,该程序只有两三百行代码,在没有 GPU 支持的机器上花几个小时也能训练完成。模型效果如下图所示:其中左图为第一次迭代的结果,右图为第 25 次迭代后的结果。
程序使用两个模型,其中一个用于生成图像,另一个用于判断图像是否为模型生成,两个深度学习网络互为逆过程,判别网络由多个卷积构成,用于层层提取特征,最终判断是否为真实图片,而生成网络由多个反卷积层构成,它通过随机噪声层层扩展生成图片。在博弈过程中两个模型各自优化,最终使模型具备生成特定类型图片的能力。
需要注意的是上例中的图片和曲线拟合都针对连续型数据,可以通过调整网络参数的方法逐渐逼近最佳值,对于生成文本一类的离散数据,则需要进一步修改模型。