英文题目:Distilling Task-Specific Knowledge from BERT into Simple Neural Networks

中文题目:从 BERT 中蒸馏指定任务知识到简单网络

论文地址:https://arxiv.org/pdf/1903.12136.pdf

领域:自然语言,深度学习

发表时间:2019

作者:Raphael Tang, 滑铁卢大学

被引量:226

代码和数据:https://github.com/qiangsiwei/bert_distill

阅读时间:2022.09.11

读后感

第一次对大型自然语言模型的蒸馏:将 BERT 模型蒸馏成 BiLSTM 模型。

介绍

在自然语言处理方面,随着 BERT,GPT 等大规模预训练模型的发展,浅层的深度学习模型似乎已经过时了。但由于资源的限制,又需要使用小而快的模型。

文章的动机是讨论:浅层模型是否真的不具备对文本的表示能力?并展示了针对于具体的任务,将 BERT 蒸馏成单层 BiLSTM 模型的方法和效果。也通过大模型(起初训练的复杂的模型,后称 Teacher/T)和小模型(蒸馏后的模型,后称 Student/S)完全不同的模型结构展示了蒸馏与模型结构无关。另外,之前蒸馏模型主要应用于图片建模,论文讨论了它在自然语言领域的使用方法。

方法

核心方法包含两部分:增加了 logit 回归目标;重建蒸馏训练数据集使训练更为有效。

模型结构

将 BERT 作为教师模型,使用单层的 BiLSTM 作为学习模型的非线性分类器,针对每一种下游任务使用不同模型。如图 -1 是对单句分类任务设计的学生模型。

Pasted image 20220912140516.png

图 -2 展示了用于预测句子匹配度的模型,它们的编码层共享同一 BiLSTM 模型。

Pasted image 20220912140643.png

为了更好地对比效果,在学生模型中,未使用注意力归一化等更多技巧。

蒸馏目标

学生模型的目标是在所有数据上,模拟老师模型的行为。除了最终的标签,老师模型预测出的概率也很重要。比如在情绪分类问题中,一些实例有很强的正面情绪,有一些情绪可能比较中性,所以除了是否,也需要预测程度。

一般预测标签的方法是:

Pasted image 20220912141318.png

文中使用了 logit 的优化方法,构造了蒸馏目标:用 MSE 来惩罚师生模型间的差异:

Pasted image 20220912141655.png

其中 z(B) 指的是老师模型 BERT,z(S) 指学生模型,在初步实验中,MSE 比软目标效果更好。

在实际训练时,也使用了传统的交叉熵(对真正目标的预测)和蒸馏损失相结合的方式,最终损失函数如下:

Pasted image 20220912142437.png

当使用有标签数据训练时,t 是实例的标签;使用无标签数据训练时,使用老师模型打标签。

蒸馏的数据增强

在蒸馏过程中,使用小的数据集不足以让老师模型展示出其所有知识,因此,使用了无标签数据扩充训练数据集,用老师模型对其打标签。

增强 NLP 数据比增强图像数据难度大,没办法使用扭曲等方法,做出的句子可能不够流畅。文中提出了几种数据增强方法:

  • 遮蔽:使用类似 BERT 的方法,这种方法能反应句中每个词对标签的贡献。
  • 基于词性的词替换:在词袋里找同一词性的词作替换,以保持原始数据的分布。
  • n-gram 采样:根据概率,随机采样 n 个连续的词,它是遮蔽方法的增强版。

实验

使用的是 BERT_LARGE 作为老师模型,针对特定任务精调,预测时获取预测的 logit 值,学生模型使用 300 维的 word2vec 作为词嵌入。主实验效果如表 -1 所示:

Pasted image 20220912144737.png

可以看到同样是使用 BiLSTM 方法,文中方法相较于其它方法有显著提升。

从表 -2 可以看到预测速度也有很大提升:

Pasted image 20220912144810.png