  • NLP (25): Faiss + SentenceBert application

    1. Sentence_Bert code

    from sentence_transformers import SentenceTransformer, SentencesDataset, util
    from sentence_transformers import InputExample, evaluation, losses
    from torch.utils.data import DataLoader
    import pandas as pd
    from root_path import root
    import os
    import time
    
    class SentenceBert(object):
        def __init__(self):
            self.model = SentenceTransformer('distiluse-base-multilingual-cased')
            data_path = os.path.join(root, "confusion_detection", "data", "sim_data")
            self.train_data = pd.read_csv(os.path.join(data_path, "train.csv"), sep="\t")
            self.train_data = self.train_data.sample(frac=1).reset_index(drop=True)  # shuffle rows
            self.val_data = pd.read_csv(os.path.join(data_path, "val.csv"), sep="\t")
            self.val_data = self.val_data.sample(frac=1).reset_index(drop=True)
            self.test_data = pd.read_csv(os.path.join(data_path, "test.csv"), sep="\t")
            self._model_dir = os.path.join(root, "confusion_detection", "checkpoints", "sentence_bert")
            _timestamp = str(int(time.time()))
            self.save_path = os.path.join(self._model_dir, _timestamp)
    
        def get_input(self):
            """从原始数据生成训练数据集"""
            train_datas = []
            y = self.train_data["y"]
            s1 = self.train_data["s1"]
            s2 = self.train_data["s2"]
            for a, b, l in zip(s1, s2, y):
                train_datas.append(InputExample(texts=[a, b], label=float(l)))
            return train_datas
    
        def eval_examples(self):
            """从原始数据得到验证集"""
            sentences1, sentences2, scores = [], [], []
            for s1, s2, l in zip(self.val_data["s1"],
                                 self.val_data["s2"],
                                 self.val_data["y"]):
                sentences1.append(s1)
                sentences2.append(s2)
                scores.append(float(l))
            return sentences1, sentences2, scores
    
        def train(self):
            """训练并保存模型"""
            train_datas = self.get_input()
            sentences1, sentences2, scores = self.eval_examples()
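            # The evaluator scores checkpoints by the Spearman correlation between
            # embedding cosine similarity and the gold labels on the validation pairs.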
            evaluator = evaluation.EmbeddingSimilarityEvaluator(sentences1, sentences2, scores)
            # Define your train dataset, the dataloader and the train loss
            train_dataset = SentencesDataset(train_datas, self.model)
            train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=32)
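            # CosineSimilarityLoss regresses the cosine similarity of the two sentence
            # embeddings toward the float label in [0, 1] built in get_input().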
            train_loss = losses.CosineSimilarityLoss(self.model)
    
            if not os.path.exists(self.save_path):
                os.makedirs(self.save_path)
            model_path = os.path.join(self.save_path, "sentence_albert.model")
            # Tune the model
            self.model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=50, warmup_steps=100,
                           evaluator=evaluator, evaluation_steps=200, output_path=model_path)
    
        def test_examples(self):
            """从原始数据得到测试数据"""
            sentences1, sentences2, scores = [], [], []
            for s1, s2, l in zip(self.test_data["s1"],
                                 self.test_data["s2"],
                                 self.test_data["y"]):
                sentences1.append(s1)
                sentences2.append(s2)
                scores.append(float(l))
            return sentences1, sentences2, scores
    
        def test_similar(self):
            """相似度测评,准确度测评"""
            model_path = os.path.join(self._model_dir, "1624013214", "sentence_albert.model")
            model = SentenceTransformer(model_path)
            sentences1, sentences2, scores = self.test_examples()
            evaluator_1 = evaluation.EmbeddingSimilarityEvaluator(sentences1, sentences2, scores)
            print(model.evaluate(evaluator_1))
            evaluator_2 = evaluation.BinaryClassificationEvaluator(sentences1, sentences2, scores)
            print(model.evaluate(evaluator_2))
    
        def encode_sentence(self):
            """获取模型向量"""
            model_path = os.path.join(self._model_dir, "1624013214", "sentence_albert.model")
            model = SentenceTransformer(model_path)
            # Sentences are encoded by calling model.encode()
            emb1 = model.encode('对什么事情?')
            emb2 = model.encode("对啊,什么事?")
            print(emb1)
            print(emb2)
            cos_sim = util.pytorch_cos_sim(emb1, emb2)
            print("Cosine-Similarity:", cos_sim)
    
        def index_faiss(self):
            import numpy as np
            import faiss  # make faiss available
            model_path = os.path.join(self._model_dir, "1624013214", "sentence_albert.model")
            model = SentenceTransformer(model_path)
            ALL = self.test_data["s1"].tolist() + self.test_data["s2"].tolist()
            ALL = list(set(ALL))
            # len(ALL) == 4739 unique sentences
            DT = model.encode(ALL)
            DT = np.array(DT, dtype=np.float32)
            # each vector has 512 dimensions, so DT has shape (4739, 512)
            # https://waltyou.github.io/Faiss-Introduce/
            index = faiss.IndexFlatL2(DT[0].shape[0])  # build an exact L2 index over 512-dim vectors
            print(index.is_trained)  # True: a flat index needs no training
            index.add(DT)  # add vectors to the index
            print(index.ntotal)  # 4739
            k = 10  # we want to see 10 nearest neighbors
            aim = 220
            D, I = index.search(DT[aim:aim + 1], k)  # sanity check: query with an already-indexed vector
            print(DT[aim:aim + 1].shape)
            print(I)
            print(D)
            print([ALL[i] for i in I[0]])
            """
            Sample output (the query itself comes back first with distance 0):
            (1, 512)
            [[ 220 3817 4396  404 2934  139 1647 4343 3430 3690]]
            [[0.         0.0626016  0.0932295  0.09561109 0.11395724 0.11639294
              0.13688776 0.17873172 0.19773623 0.20516896]]
            """
    
    if __name__ == '__main__':
        SentenceBert().encode_sentence()
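
    The script above depends on the author's local setup (root_path, train/val/test CSVs with s1/s2/y columns, and a timestamped checkpoint directory), so it does not run standalone. Below is a minimal, self-contained sketch of the same fit-and-evaluate flow on placeholder in-memory pairs; the sentences and hyperparameters are illustrative, not from the original data.

    from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
    from torch.utils.data import DataLoader

    model = SentenceTransformer("distiluse-base-multilingual-cased")
    # Placeholder labeled pairs: 1.0 = similar, 0.0 = dissimilar
    train_examples = [
        InputExample(texts=["What is this about?", "What's the matter?"], label=1.0),
        InputExample(texts=["What is this about?", "The weather is nice today."], label=0.0),
    ]
    # model.fit supplies its own collate_fn, so a plain DataLoader over InputExamples works
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
    train_loss = losses.CosineSimilarityLoss(model)
    evaluator = evaluation.EmbeddingSimilarityEvaluator(
        ["What is this about?"], ["What's the matter?"], [1.0])
    model.fit(train_objectives=[(train_dataloader, train_loss)],
              epochs=1, warmup_steps=10, evaluator=evaluator)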

    2. Faiss combined with Sentence_Bert code

    from root_path import root
    from sentence_transformers import SentenceTransformer, SentencesDataset, util
    import os
    import pickle
    import jieba
    import faiss
    import time
    import numpy as np
    
    class FaissIndex(object):
        def __init__(self):
            self.data_path = os.path.join(root, "confusion_detection", "data", "raw_data", "all.txt")
            self.labels = []
            self.sentences = []
            self.raws = []
            with open(self.data_path, "r", encoding="utf8") as f:
                for line in f.readlines():
                    data_tuple = line.replace("
    ", "").split("  ")
                    self.labels.append(data_tuple[0])
                    senten = data_tuple[1]
                    self.raws.append(senten)
                    sentence = list(jieba.cut(senten))
                    sen = " ".join(sentence)
                    self.sentences.append(sen)
            self.tfidf_model = os.path.join(root, "confusion_detection", "checkpoints",
                                            "tf_idf", "tfidf.model")
            self.tfidf_faiss = os.path.join(root, "confusion_detection", "checkpoints",
                                            "faiss_model", "tfidf_faiss.model")
    
            self.bert_model_path = os.path.join(root, "confusion_detection", "checkpoints",
                                                "sentence_bert", "1624013214", "sentence_albert.model")
            self.bert_faiss = os.path.join(root, "confusion_detection",
                                           "checkpoints", "faiss_model", "bert_faiss.model")

            self.bert = SentenceTransformer(self.bert_model_path)
    
            with open(self.tfidf_model, 'rb') as f:
                self.tfidf_vectorizer = pickle.load(f)
    
        def write_tfidf_faiss(self):
            """Encode every sentence with TF-IDF and persist a flat L2 index."""
            res = self.tfidf_vectorizer.transform(self.sentences).toarray()
            index = faiss.IndexFlatL2(res.shape[1])
            index.add(res.astype("float32"))
            faiss.write_index(index, self.tfidf_faiss)
    
        def read_tfidf_faiss(self, query):
            """Look up the top-k nearest sentences for a query via the TF-IDF index."""
            index = faiss.read_index(self.tfidf_faiss)
            sentence = list(jieba.cut(query))
            sen = " ".join(sentence)
            res = self.tfidf_vectorizer.transform([sen]).toarray()
            k = 5
            top_k = index.search(res.astype("float32"), k)
            result = [self.raws[_id] for _id in top_k[1].tolist()[0]]
            y = [self.labels[_id] for _id in top_k[1].tolist()[0]]
            print(result)
            print(y)
            return result, y
    
        def write_bert_faiss(self):
            """Encode all raw sentences with Sentence-BERT and persist a flat L2 index."""
            res = self.bert.encode(self.raws)
            index = faiss.IndexFlatL2(res.shape[1])
            index.add(np.array(res).astype("float32"))
            faiss.write_index(index, self.bert_faiss)
    
        def read_bert_faiss(self, query):
            """Return distances, sentences and labels for the top-k neighbors of a query."""
            index = faiss.read_index(self.bert_faiss)
            query_vector = self.bert.encode([query])
            k = 10
            top_k = index.search(np.array(query_vector).astype("float32"), k)
            result = [self.raws[_id] for _id in top_k[1].tolist()[0]]
            y = [self.labels[_id] for _id in top_k[1].tolist()[0]]
            return top_k[0][0], result, y
    
        def detect(self):
            """Search every corpus sentence against the index and report the mean top-k distance."""
            index = faiss.read_index(self.bert_faiss)
            query_vector = self.bert.encode(self.raws)
            k = 5
            score, line_index = index.search(np.array(query_vector).astype("float32"), k)
            avg = np.mean(score)  # mean L2 distance over all queries and neighbors
            print(avg)
    
    if __name__ == '__main__':
        _faiss = FaissIndex()
        while True:
            q = input()
            top_k, result, y = _faiss.read_bert_faiss(q)
            print(top_k)
            print(result)
            print(y)
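
    One caveat: the model is fine-tuned with CosineSimilarityLoss, while the indexes above rank by raw L2 distance over unnormalized vectors, which is not quite the same ordering. L2-normalizing the embeddings and switching to an inner-product index reproduces cosine ranking exactly. A minimal sketch of that variant; the corpus here is a placeholder, and faiss.normalize_L2 / IndexFlatIP are stock Faiss APIs:

    import numpy as np
    import faiss
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("distiluse-base-multilingual-cased")
    corpus = ["What is this about?", "What's the matter?", "The weather is nice today."]
    emb = np.asarray(model.encode(corpus), dtype="float32")
    faiss.normalize_L2(emb)                   # in-place L2 normalization
    index = faiss.IndexFlatIP(emb.shape[1])   # inner product == cosine on unit vectors
    index.add(emb)

    query = np.asarray(model.encode(["What is going on?"]), dtype="float32")
    faiss.normalize_L2(query)
    scores, ids = index.search(query, 2)      # higher score = more similar
    print([corpus[i] for i in ids[0]], scores[0])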
  • Original post: https://www.cnblogs.com/zhangxianrong/p/14919448.html