语义相似度模型SBERT——一个挛生网络的优美范例 | Word count: 1.5k | Reading time: 5min | Post View:
语义相似度模型
SBERT——一个挛生网络的优美范例
论文地址:https://arxiv.org/abs/1908.10084
论文中文翻译:https://www.cnblogs.com/gczr/p/12874409.html
源码下载:https://github.com/UKPLab/sentence-transformers
相关网站:https://www.sbert.net/
“论文中文翻译”已相当清楚,故本篇不再翻译,只简单介绍 SBERT
的原理,以及训练和使用中文相似度模型的方法和效果。
原理
挛生网络 Siamese network(后简称 SBERT),其中 Siamese
意为“连体人”,即两人共用部分器官。SBERT 模型的子网络都使用 BERT
模型,且两个 BERT 模型共享参数。当对比 A,B
两个句子相似度时,它们分别输入 BERT
网络,输出是两组表征句子的向量,然后计算二者的相似度;利用该原理还可以使用向量聚类,实现无监督学习任务。
挛生网络有很多应用,比如使用图片搜索时,输入照片将其转换成一组向量,和库中的其它图片对比,找到相似度最高(距离最近)的图片;在问答场景中,找到与用户输入文字最相近的标准问题,然后给出相应解答;对各种文本标准化等等。
衡量语义相似度是自然语言处理中的一个重要应用,BERT
源码中并未给出相应例程(run_glue.py
只是在其示例框架内的简单示例),真实场景使用时需要做大量修改;而 SBERT
提供了现成的方法解决了相似度问题,并在速度上更有优势,直接使用更方便。
SBERT 对 Pytorch 进行了封装,简单使用该工具时,不仅不需要了解太多
BERT API 的细节,Pytorch 相关方法也不多,下面来看看其具体用法。
配置环境
需要注意的是机器需要能正常配置 BERT 运行环境,如
GPU+CUDA+Pytorch+Transformer 匹配版本。
1 $ pip install sentence_transformers
下载源码
1 $ git clone https://github.com/UKPLab/sentence-transformers.git
模型预测
在未进行调优(fine-tune)前,使用预训练的通用中文 BERT
模型也可以达到一定效果,下例是从几个选项中找到与目标最相近的字符串。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 from sentence_transformers import SentenceTransformer import scipy.spatial embedder = SentenceTransformer('bert-base-chinese') corpus = ['这是一支铅笔', '关节置换术', '我爱北京天安门', ] corpus_embeddings = embedder.encode(corpus) # 待查询的句子 queries = ['心脏手术','中国首都在哪里'] query_embeddings = embedder.encode(queries) # 对于每个句子,使用余弦相似度查询最接近的n个句子 closest_n = 2 for query, query_embedding in zip(queries, query_embeddings): distances = scipy.spatial.distance.cdist([query_embedding], corpus_embeddings, "cosine")[0] # 按照距离逆序 results = zip(range(len(distances)), distances) results = sorted(results, key=lambda x: x[1]) print("======================") print("Query:", query) print("Result:Top 5 most similar sentences in corpus:") for idx, distance in results[0:closest_n]: print(corpus[idx].strip(), "(Score: %.4f)" % (1-distance))
训练中文模型
模型训练方法
训练原理:https://www.sbert.net/docs/training/overview.html
训练示例说明:https://www.sbert.net/examples/training/sts/README.html
训练示例代码:examples/training/sts/training_stsbenchmark.py
训练中文模型
把示例中的 bert-base-cased 换成
bert-base-chinese,即可下载和使用中文模型。需要注意的是:中文和英文词库不同,不能将中文模型用于英文数据训练。
下载中文训练数据
下载信贷相关数据,csv 数据 7M 多,约 10W
条训练数据,可在下例中使用
1 2 $ git clone https://github.com/lixuanhng/NLP_related_projects.git $ ls NLP_related_projects/BERT/Bert_sim/data
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 from torch.utils.data import DataLoaderimport mathfrom sentence_transformers import SentenceTransformer, LoggingHandler, losses, models, utilfrom sentence_transformers.evaluation import EmbeddingSimilarityEvaluatorfrom sentence_transformers.readers import InputExampleimport loggingfrom datetime import datetimeimport sysimport osimport pandas as pdmodel_name = 'bert-base-chinese' train_batch_size = 16 num_epochs = 4 model_save_path = 'test_output' logging.basicConfig(format ='%(asctime)s - %(message)s' , datefmt='%Y-%m-%d %H:%M:%S' , level=logging.INFO, handlers=[LoggingHandler()]) word_embedding_model = models.Transformer(model_name) pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode_mean_tokens=True , pooling_mode_cls_token=False , pooling_mode_max_tokens=False ) model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) train_samples = [] dev_samples = [] test_samples = [] def load (path ): df = pd.read_csv(path) samples = [] for idx,item in df.iterrows(): samples.append(InputExample(texts=[item['sentence1' ], item['sentence2' ]], label=float (item['label' ]))) return samples train_samples = load('/workspace/exports/git/NLP_related_projects/BERT/Bert_sim/data/train.csv' ) test_samples = load('/workspace/exports/git/NLP_related_projects/BERT/Bert_sim/data/test.csv' ) dev_samples = load('/workspace/exports/git/NLP_related_projects/BERT/Bert_sim/data/dev.csv' ) train_dataloader = DataLoader(train_samples, shuffle=True , batch_size=train_batch_size) train_loss = losses.CosineSimilarityLoss(model=model) evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev' ) warmup_steps = math.ceil(len (train_dataloader) * num_epochs * 0.1 ) model.fit(train_objectives=[(train_dataloader, train_loss)], evaluator=evaluator, epochs=num_epochs, evaluation_steps=1000 , warmup_steps=warmup_steps, output_path=model_save_path) model = SentenceTransformer(model_save_path) test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test' ) test_evaluator(model, output_path=model_save_path)
测试结果
直接使用预训练的英文模型,测试集正确率 21%
直接使用预训练的中文模型,测试集正确率 30%
使用 1000 个用例的训练集,4 次迭代,测试集正确率 51%
使用 10000 个用例的训练集,4 次迭代,测试集正确率 68%
使用 100000 个用例的训练集,4 次迭代,测试集正确率 71%
一些技巧
除了设置超参数以外,也可通过构造训练数据来优化 SBERT
网络,比如:构造正例时,把知识“喂”给模型,如将英文缩写与对应中文作为正例对训练模型;构造反例时用容易混淆的句子对训练模型(文字相似但含义不同的句子;之前预测出错的实例,分析其原因,从而构造反例;使用知识构造容易出错的句子对),以替代之前的随机抽取反例。
参考