Transformer-XL 框架

1 引入

Transformer-XL 超长上下文的注意力模型,出自 CMU 和 Google Brain 在 2019 年 1 月发表的论文:《Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context》。其中 XL 是 extra long 的缩写,意为额外长度。论文地址:https://arxiv.org/pdf/1901.02860.pdf 先简单举例 Transformer XL 与 Transformer 的区别。比如有以下数据:

“小说是以刻画人物形象为中心,通过完整的故事情节和环境描写来反映社会生活的文学体裁。”

如果把序列长度设为十个字,代入模型时数据被切分为:

“小说是以刻画人物形象”(序列一)

“为中,通过完整的故”(序列二)

……

在训练第二个序列时,它的意思是不完整的,Transformer 计算第二个序列中的第三字“心”时只能通过前两个字“为中”作为输入计算,而 Transformer-XL 可以把序列一中的十个字同时作为输入。

切分,尤其当英文中使用字符作为序列中的元素时,如果一个单词被切成两部分,分别位于前后两个序列中,必然影响模型效果。Transformer-XL 让之前的序列也能参与到当前序列的预测中来,由此解决了长序列依赖问题,以及序列切断问题。

Transformer-XL 既不像 GPT-2 一样使用海量数据训练,也没像 ERNIE 加入自然语言相关领域的知识,它主要通过改进模型的架构提高性能。它在字符级和词的级别上表现都很好(第一个在这两方面都超过 RNN 的自注意力模型),该方案不仅能应用于自然语言处理,在其它序列问题中也能发挥很好的效果。测试证明其在小数据集上也表现优异。

2 原理

Transformer 模型用 Self-Attention 自注意力机制替换了循环网络 RNN,而 Transformer-XL 再次使用 RNN,处理序列之间的连续性。Transformer XL 有两点重要创新:循环机制(Recurrence Mechanism)和相对位置编码(Relative Positional Encoding)。

循环机制

尽管 Transformer 模型可以处理较长的上下文关系,但仍需要在训练时将文章切分成固定长度的序列,再代入模型。而模型学到的也是各序列内部的规律。

  • 若不切分,字串太长,尤其是以字符为单位时,计算注意力过于复杂。
  • 按标点或按段切分,使程序效率下降。
  • 按固定长度切分后,前后的语义被切断。

Transformer XL 在输入数据的每个段上仍使用自注意力方法,并使用循环机制来学习连续段之间的依赖关系。

图片摘自论文

Transformer 模型的依赖关系如上图 (a) 中的灰色线条所示,在每个序列中,当前层的输入取决于前一层的输出;Transformer-XL 模型的依赖关系又加入了绿色连线,使当前层的输入取决于本序列和前一序列前一层的输出。具体公式如下:

其中 h 为隐藏层,n 为层数,r 为序列数,W 为模型参数。

式一计算当前第 n-1 隐藏层时,考虑了当前序列 r 上一个序列 r-1 的隐藏层值,其中 SG 意为 stands for stop-gradient 停止计算梯度,这样即运用了前一序列生成的数据,又不对其反向传播调参,节省了算力;中括号里的圆圈为连接两个隐藏层。

式二计算注意力所需的 q,k,v,q 用于查询当前位置,k 用于提供相关位置信息,v 用于提供相关位置的值。其中 k 和 v 使用了包括上个序列信息的隐藏层,而查询 q 只与当前序列相关。另外,第 n 个层是通过前一个序列和当前序列的 n-1 层算出来的,这和基础的循环网络 RNN 有所不同。

式三将 q,k,v 代入 Transformer 算法,计算隐藏层 n。

上述方法在计算过程中保留了前一个序列的隐藏层输出h,使得评价过程中不需要每次从头计算,也节约了算力。

相对位置编码

每个序列有其各自的位置编码,当使用多个序列作为输入时,则会出现位置冲突的问题。解决方法是将序列内部的绝对位置编码变为相对位置编码,并把在一开始计算位置编码,移到注意力打分时做计算。直觉上看,相对位置比绝对位置更重要,比如上例中“整”的前一个位置是“完”,一定比“整”在位置八时第七位置是“完”更合理。

具体方法是修改计算 Attention 的算法:

上式中 E 表示词嵌入,U 表示绝对位置信息,R 为相对位置信息,W 为模型参数,i 是查询元素,j 是相关元素。

式四使用绝对位置计算,先将词嵌入 E 和绝对位置 U 相加后,与参数相乘计算重要性权重,第二行将其展开。

式五使用相对位置计算,首先用相对位置 R 代替绝对位置 U;由于不需要绝对位置 Ui, 引入了 u,v 参数,取代 UiTWqT;另外将参数 Wk拆分成了 Wk,E和 Wk,R。式五又可分为四部分,含义分别是:

  1. j 的内容相对于 i 的影响

  2. i 与 j 的距离对于 i 的影响

  3. j 的内容相对于整体的影响

  4. i 与 j 的距离对于整体的影响

相对编码和循环网络二者结合后才能提升模型效果。如果只加循环网络,则前一序列的位置编码可能与当前序列的位置编码混淆;如果只使用相对位置编码,那么无法解决句子的截断问题。

3 代码

Git 代码地址:

https://github.com/kimiyoung/transformer-xl

作者提供了 Tensorflow、PyTorch 两种代码实现,以 Pytorch 为例,其模型实现在 pytorch/mem_transformer.py 代码中,其模型的代码几乎是 transformer 代码量的两倍,但命令名规则一致。

层结构、注意力、位置编码与基本的 Transformer 模型大同小异,改进的核心在:保存之前隐藏层数据的 mems 和计算相对位置的 Rel*LearnableMultiHeadAttn 部分。