定向写作模型 CTRL

介绍

CTRL 全称为 Conditional Transformer Language 有条件的文本生成模型,它始于 Salesforce 在 2019 年发布的论文《A Conditional Transformer Language Model for Controllable Generation》,该模型用于定向写作。论文地址如下:https://arxiv.org/pdf/1909.05858.pdf

这两年非常流行的 BERT 和 GPT-2 都基于 Transformer 模型,虽然代码量不大,逻辑也并不复杂,但是极大规模的数据量、训练强度和模型容量,以及利用无监督文本建模,使模型的能力空前强大,在一些领域已超过人类水平。

GPT-2 使用各种类型的文章训练模型,包括散文、小说、新闻、科技文章,用它写作的文章也综合了各种风格。如果想生成“金庸风格”的小说,则需要用所有金庸先生的小说重新训练模型;或者从原模型中提取特征构造新模型;也可以在原有模型基础上 fine-tuning。如需撰写新闻稿,则需要另行训练。

GPT-2 模型根据文章开头的内容,继续向后联想,控制不了文章的具体内容,因此也有人把它称为“造谣神器”。除了瞎编,它的实际用途又在何处?如何更好的控制文章的内容,生成有价值的文本。

CTRL 是继 GPT-2 后出现的写作模型,同样也基于 Transformer。与之前模型不同的是:它无需进一步训练就可以解决特定领域的具体问题。CTRL 模型可以指定文章的领域、风格、主题、时间、实体,实体间的关系,以及任务相关的行为等等,因此可以将其看成命题作文。它使用 140G 数据训练,参数规模 1.63 billion(16 亿,比 GPT-2 更大)。模型维度 1280 维,48 层 EncoderLayer,16 头 Attention,也是一个体量巨大的模型。

CTRL 模型的最大优势是在生成文本时可指定文章的类型,同一模型可以写作不同风格的文章。论文也举出了用同一开头续写不同类型文章的实例,比如高分评论和低分评论的差异;“刀”在购物评论和恐怖小说的场景中生成的不同文章;按时间、国家写出文章中涉及的不同总统等等。

不同的角度,有不同的答案。换言之,CTRL 关注了语料在不同场景中的不同含义。模型更符合实际应用的场景:使用者在同一时间,只可能生成某一特定类型,同时又希望单个模型支持生成各种类型的文章,CTRL 可视为多任务学习。

由人写一个故事梗概:时间、地点、人物、事件,用模型按照某种风格遣词造句填充内容。它与之前的问答系统、文章概要又有何区别呢?原来的模型先用无监督数据训练模型,然后用有标注的问与答,内容与概要代入模型调优。标注数据毕竟有限;CTRL 则海量的无监督数据进行了分类,这类似于简单的自动标注,让数据从一开始就更有针对性。

具体实现

CTRL 的核心思想是从无监督的海量数据集中定位文章所在的领域。大多数训练数据都从网络上抓取,在抓取过程中通过网址标题等信息估计它所在领域,并作为特征,代入训练。从而让模型写出各种类型的文章,同理在问答等领域中运用此技术,也可以更有针对性地解决问题。

CTRL 底层同样也基于 Transformer,使用了其中 Encoder 部分,模型底层改动不大。之前的模型是根据词序列中的前 n-1 个词计算下一个词 n 是哪个词的可能性,如式一所示:

(式一)

CTRL 又加入了条件 c,即文章的控制信息如类型,在计算概率的同时考虑条件 c。具体操作是在每一个序列的具体内容前加了入类型描述,使得在计算 Attention 过程中,类型与序列中的所有元素建立联系。如式二所示:

(式二)

代码中定义了一些常见,并且可以在抓取时识别的类型,如下图示:

除了类型,还支持将标题、下载的地址(有些下载地址中包含时间、实体等信息)……放在正文之前。除了上述改进,它还引入了新算法优化了后序词的筛选逻辑。

代码分析

CTRL 官方代码可从以下网址下载:https://github.com/salesforce/ctrl

其中包括 TensorFlow 和 Pytorch 两种实现方法,又细分为训练和应用两部分。以 Pytorch 为例,其核心代码主要在 pytorch_transformer.py 和 pytorch_generation.py 两个文件中。pytorch_transformer.py 主要实现了 Transformer 模型,其内容是基础版 Transformer 模型的 Encoder 部分。pytorch_generation.py 用于使用该模型撰写文章,其中包含解析数据和调用模型的方法。需要注意的是,使用该模型时,序列的第一位应为类型。模型训练部分在 training_utils 目录中用 TensorFlow 实现。

相对官方代码,更推荐 Hugging Face 团队发布的 Transformer 例程集,支持 TensorFlow 和 Pytorch 两种实现方式,其中也包含 CTRL 的实现,源码位置在:

https://github.com/huggingface/transformers/blob/master/src/transformers/

实现 Pytorch 版本 CTRL 的代码有:configuration_ctrl.py, modeling_ctrl.py, tokenization_ctrl.py,其中核心是 modeling_ctrl.py,建议读者用 debug 工具跟踪调用模型的完整流程,查看每一步的输入及输出,便可完全理解该模型。调用方法如下:

1
2
3
4
5
6
7
01 import torch
02 from transformers import CTRLTokenizer, CTRLModel
03 tokenizer = CTRLTokenizer.from_pretrained('ctrl')
04 model = CTRLModel.from_pretrained('ctrl')、input_ids =
05 torch.tensor(tokenizer.encode("Links Hello, my dog is cute",
06 add_special_tokens=True)).unsqueeze(0) # Batch size 1
07 outputs = model(input_ids)

注意:运行时将下载 6.5G 的预训练模型,虽然模型很大,但在没有 GPU 且机器性能不高的情况下也能正常调用模型预测部分。

总结

CTRL 不仅是一个自然语言处理问题的解决方案,同样也可应用到其它的序列处理问题之中。从 NLP 的演进可以看到,用无标注数据训练模型,生成一般性“常识”逐渐成为主流。人工不可能标注海量信息,目前,人们正试图使用更多知识和分析方法处理信息,并将知识融入模型结构,使人与工具更好地结合,并生成更加可控的模型。