序列对抗网络 SeqGAN

SeqGAN 源自 2016 年的论文《SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient》,论文地址:https://arxiv.org/pdf/1609.05473.pdf。其核心是用生成对抗网络处理离散的序列数据。

之前介绍了使用 GAN 生成图像的方法,由于图像是连续数据,可以使用调整梯度的方法逐步生成图像,而离散数据很难使用梯度更新。在自然语言处理(NLP)中使用 GAN 生成文字时,由于词索引与词向量转换过程中数据不连续,微调参数可能不起作用;且普通 GAN 的判别模型只对生成数据整体打分,而文字一般都是逐词(token)生成,因此无法控制细节。SeqGAN 借鉴了强化学习(RL)的策略,解决了 GAN 应用于离散数据的问题。

概念

与基本的 GAN 算法一样,SeqGAN 的基本原理也是迭代训练生成模型 G 和判别模型 D。假设用 G 生成一个词序列组成句子,由 D 来判别这个句子是训练集中的真实句子(True data),还是模型生成的句子(Generate);最终目标是用模型 G 生成以假乱真的句子,让 D 无法分辨。其操作过程如下:

图片摘自论文

与普通对抗不同的是,在单次操作中,模型多次调用生成模型 G 和判别模型 D。以生成文字为例,右侧的每一个红圈是一个生成词的操作,State 为已生成的词串,在生成下一个词 Next action 时,先调用生成模型 G 生成多个备选项,然后使用判别模型对各个选项评分(reward),根据评分选择最好的策略 Policy,并调整策略模型(Policy Gradient)。

强化学习

SeqGAN 主要借鉴了强化学习中的方法,如果不了解强化学习很难看懂论文中的公式和推导,下面先对强化学习做一个简单的介绍。

强化学习的核心是在实践中通过不断试错来学习最好的策略,一般强化学习学到的是一系列决策,其目标是最大化长期收益,例如围棋比赛中当前的操作不仅需要考虑接下来一步的收益,还需要考虑未来多步的收益。

强化学习有几个核心概念:状态 s(State)、动作 a(Action)、奖励 r(Reward)。以生成词系列为例,假设词系列是 Y1:T=(y1,...,yt,...,yT),在第 t 个时间步,状态 s 是先前已生成的词 (y1,...,yt−1),动作 a 是如何选择要生成的词 yt,这也是生成模型的工作 Gθ(yt|Y1:t−1),它通过前 t-1 个词以及模型参数θ来选择下一个词,确定了该词之后,状态也随之改变成 s’,对应词 (y1,...,yt),以此类推,最终生成的系列 (y1,...,yt,...,yT),对序列的评分就是奖励 r,如果生成的系列成功地骗过了判别模型 D,则得 1 分,如果被识别出是机器生成的则得 0 分。

强化学习中还有两个重要概念:动作价值 action-value 和状态价值 state-value。简单地说,动作价值就是在某个状态选择某一动作是好是坏,如果能确定每一个动作对应的价值,就很容易做出决定。动作价值不仅与当前动作有关,还涉及此动作之后一系列动作带来的价值。状态价值也是同理,它表示某个状态的好坏。

SeqGAN 原理

SeqGAN 中生成模型 G 的目标是最大化期望奖励 reward,简单说就是做出可能是奖励最大的选择,其公式如下:

上式中 J 是目标函数,E[] 是期望,R 是序列整体的奖励值,s 是状态,θ是生成模型的参数,y 是生成的下一个词(动作 action),G 是生成模型,D 是判别模型,Q 是动作价值(action-value)。简单地解释公式:希望得到一组生成模型 G 参数θ;能在 s0 处做出最佳选择,获取最大回报 RT,而如何选择动作又取决于动作的价值 Q。

动作价值算法如下:

动作价值是由判别函数 D 判定的,第 T 个时间步是最后一个时间步,上式中列出的是判别函数对完整系列的打分。若判别该序列为真实文本,则奖励值 R 最大。

在生成第 t 个词时,如何选择(动作a)涉及前期已生成的 t-1 个词(状态s),以及后续可能的情况,假设此时用模型 Gβ生成N个备选词串(Yt:T),再用判别模型 D 分别对生成的 N 句(Y1:T)打分,此时使用了蒙特卡洛方法(MC),如下式所示:

这里的生成模型 Gβ与前面 Gθ通常使用同样的模型参数,有时为了优化速度也可使用不同模型参数。这里使用的蒙特卡洛算法,像下棋一样,不仅要考虑当前一步的最优解,还需要考虑接下来多步组合后的最优解,用于探索此节点以及此节点后续节点(Yt:T)的可能性,也叫 roll-out 展开,是蒙特卡洛搜索树中的核心技巧。

根据不同的时间步,采取不同的动作价值计算方法:

在最后一个时间步 t=T 时,直接使用判别函数 D 计算价值;在其它时间步,使用生成模型 Gβ和蒙特卡洛算法生成N个后续备选项,用判别函数 D 打分并计算分数的均值。

SeqGAN 与 GAN 模型相同,在训练生成器 G 的同时,判别器 D 也迭代地更新其参数。

此处公式与 GAN 相同,即优化判别模型 D 的参数φ,使其对真实数据 Pdata 尽量预测为真,对模型 Gθ生成的数据尽量预测为假。

主要流程

其主要流程如下:

图片摘自论文
  • 程序定义了基本生成器 Gθ,roll-out 生成器 Gβ,判别器 D,以及训练集 S。

  • 用 MLE(最大似然估计)预训练生成器 G。(2 行)

  • 用生成器生成的数据和训练集数据预训练判别器 D。(4-5 行)

  • 进入迭代对抗训练:(6 行)

  • 训练生成器(7-13 行) 在每一个时间步计算 Q,这是最关键的一步,它利用判别器D、roll-out 生成器 Gβ以及蒙特卡罗树搜索计算行为价值,然后更新 policy gradient 策略梯度。

  • 训练判别器(14-17 行) 将训练数据作为正例,生成器生成的样例作为反例训练判别模型 D。

代码

推荐以下代码:

TensorFlow 代码(官方):https://github.com/LantaoYu/SeqGAN 

Pytorch 代码:https://github.com/suragnair/seqGAN

其中 Pytorch 代码比较简单,与论文中描述的模型不完全一致,比如它的 G 和 D 都使用 GRU 作为基础模型,也没有实现 rollout 逻辑,只是一个简化的版本,优点在于代码简单,适合入门。