  • NLP (13): Computing text similarity with word2vec and a Siamese LSTM

    1. Defining the dataset

    my_dataset.py
    import torch.utils.data as data
    
    
    class MyDataset(data.Dataset):
        """Wraps three parallel sequences: left sentences, right sentences, labels."""

        def __init__(self, texta, textb, label):
            self.texta = texta
            self.textb = textb
            self.label = label

        def __getitem__(self, item):
            # Return one (sentence_a, sentence_b, label) triple; the sentences
            # stay as raw strings and are embedded later, batch by batch.
            texta = self.texta[item]
            textb = self.textb[item]
            label = self.label[item]
            return texta, textb, label

        def __len__(self):
            return len(self.texta)
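
    Because MyDataset just indexes three parallel sequences, it drops straight into a standard DataLoader; the raw strings are only turned into embeddings later, one batch at a time. A minimal usage sketch (the toy sentence pairs and labels below are made up for illustration):

    import torch
    from torch.utils.data import DataLoader
    from my_dataset import MyDataset

    texta = ["浙江杭州富阳区银湖街", "富阳区银湖街道新常村"]  # hypothetical left-hand sentences
    textb = ["浙江富阳区银湖街", "浙江杭州富阳区银湖街"]      # hypothetical right-hand sentences
    label = torch.LongTensor([1, 0])  # 1 = similar pair, 0 = dissimilar pair

    loader = DataLoader(MyDataset(texta, textb, label), batch_size=2, shuffle=True)
    for a, b, y in loader:
        print(a, b, y)  # a and b are tuples of raw strings; y is a label tensor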

    2. Defining the word embedding

    my_word2vec.py
    from gensim.models.fasttext import FastText
    import torch
    import numpy as np
    import os
    
    class WordEmbedding(object):
        def __init__(self):
            parent_path = os.path.split(os.path.realpath(__file__))[0]
            # Project root is everything up to the "models" directory, e.g. E:\personas\semantics
            self.root = parent_path[:parent_path.find("models")]
            self.word_fasttext = os.path.join(self.root, "checkpoints", "word2vec", "word_fasttext.model")
            self.char_fasttext = os.path.join(self.root, "checkpoints", "word2vec", "char_fasttext.model")
            # The character-level model is the one actually used below.
            self.model = FastText.load(self.char_fasttext)
    
        def sentenceTupleToEmbedding(self, data1, data2):
            # Pad both batches to the length of the longest sentence (in characters).
            a_max_len = max(len(str(sentence)) for sentence in data1)
            b_max_len = max(len(str(sentence)) for sentence in data2)
            seq_len = max(a_max_len, b_max_len)
            a = self.sequence_vec(data1, seq_len)  # batch_size, seq_len, embedding
            b = self.sequence_vec(data2, seq_len)
            return torch.FloatTensor(a), torch.FloatTensor(b)

        def sequence_vec(self, data, seq_len):
            data_vec = []
            for sentence in data:
                # One 128-dim vector per character; characters missing from the
                # vocabulary are skipped.
                char_vecs = [self.model.wv[ch] for ch in list(str(sentence)) if ch in self.model.wv]
                # reshape keeps an all-OOV (empty) sentence from breaking vstack below
                char_vecs = np.array(char_vecs).reshape(-1, 128)
                pad = np.zeros((seq_len - char_vecs.shape[0], 128))
                data_vec.append(np.vstack((char_vecs, pad)))
            return np.array(data_vec)  # batch_size, seq_len, embedding
    
    if __name__ == '__main__':
        word = WordEmbedding()
        data1 = ("浙江杭州富阳区银湖街黄先生的外卖","浙江杭州富阳区银湖街黄先生的外卖")
        data2 = ("富阳区浙江富阳区银湖街道新常村","浙江杭州富阳区银湖街黄先生的外卖")
        a, b = word.sentenceTupleToEmbedding(data1, data2)
        print(a.shape)
        print(b)
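
    The char_fasttext.model loaded above has to be trained beforehand; the post does not show that step. Below is a minimal sketch of how such a character-level model could be produced with gensim (gensim 4.x is assumed, and the toy corpus and save path are placeholders):

    from gensim.models.fasttext import FastText

    # Toy character-level corpus: each "sentence" is a list of characters.
    corpus = [list("浙江杭州富阳区银湖街"), list("富阳区银湖街道新常村")]

    model = FastText(vector_size=128, window=5, min_count=1)  # 128-dim vectors, as assumed by sequence_vec above
    model.build_vocab(corpus)
    model.train(corpus, total_examples=model.corpus_count, epochs=10)
    model.save("char_fasttext.model")  # placeholder path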

    3. Defining the model

    my_lstm.py
    import torch
    from torch import nn
    
    class SiameseLSTM(nn.Module):
        def __init__(self, input_size):
            super(SiameseLSTM, self).__init__()
            # One shared LSTM encodes both sentences (shared weights make it "Siamese").
            self.lstm = nn.LSTM(input_size=input_size, hidden_size=10, num_layers=1, batch_first=True)
            # 10-dim state from each branch, concatenated to 20, mapped to one logit.
            self.fc = nn.Sequential(
                nn.Linear(20, 1),
            )

        def forward(self, data1, data2):
            out1, (h1, c1) = self.lstm(data1)
            out2, (h2, c2) = self.lstm(data2)
            # Take the output at the last time step of each sequence.
            pre1 = out1[:, -1, :]
            pre2 = out2[:, -1, :]
            pre = torch.cat([pre1, pre2], dim=1)
            out = self.fc(pre)
            return out
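
    Both inputs pass through the same nn.LSTM, which is what makes the network Siamese: the two branches share weights. The last time step of each branch is 10-dimensional, and the two are concatenated into a 20-dim vector, which is why the head is nn.Linear(20, 1). A quick shape check, as a sketch with arbitrary batch size and sequence length:

    import torch
    from my_lstm import SiameseLSTM

    net = SiameseLSTM(input_size=128)
    a = torch.randn(4, 20, 128)  # batch of 4, 20 characters, 128-dim embeddings
    b = torch.randn(4, 20, 128)
    print(net(a, b).shape)  # torch.Size([4, 1]): one raw similarity logit per pair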

    4. Defining the training runner

    run__lstm.py
    import torch
    import os
    from torch.utils.data import DataLoader
    from my_dataset import MyDataset
    import pandas as pd
    import numpy as np
    from my_lstm import SiameseLSTM
    import torch.nn as nn
    from my_word2vec import WordEmbedding
    
    
    class RunLSTM:
        def __init__(self):
            self.learning_rate = 0.001
            self.device = torch.device("cpu")
            parent_path = os.path.split(os.path.realpath(__file__))[0]
            # Project root is everything up to the "models" directory, e.g. E:\personas\semantics
            self.root = parent_path[:parent_path.find("models")]
            self.train_path = os.path.join(self.root, "datas", "bert_data", "sim_data", "train.csv")
            self.val_path = os.path.join(self.root, "datas", "bert_data", "sim_data", "val.csv")
            self.test_path = os.path.join(self.root, "datas", "bert_data", "sim_data", "test.csv")
            self.batch_size = 64
            self.epoch = 50
            # BCEWithLogitsLoss applies the sigmoid internally, so the model outputs raw logits.
            self.criterion = nn.BCEWithLogitsLoss().to(self.device)
            self.word = WordEmbedding()
            self.check_point = os.path.join(self.root, "checkpoints", "char_bilstm", "char_bilstm.pth")
    
        def get_loader(self, path):
            # Expects a tab-separated file with columns s1, s2 and a binary label y.
            data = pd.read_csv(path, sep="\t")
            d1, d2, y = data["s1"], data["s2"], list(data["y"])

            dataset = MyDataset(d1, d2, torch.LongTensor(y))
            data_iter = DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
            return data_iter
    
        def binary_acc(self, preds, y):
            # Round the sigmoid of the logits to 0/1 and compare against the labels.
            preds = torch.round(torch.sigmoid(preds))
            correct = torch.eq(preds, y).float()
            acc = correct.sum() / len(correct)
            return acc
    
        def train(self, mynet, train_iter, optimizer, criterion, epoch, device):
            avg_acc = []
            avg_loss = []
            mynet.train()
            for batch_id, (data1, data2, label) in enumerate(train_iter):
                try:
                    a, b = self.word.sentenceTupleToEmbedding(data1, data2)
                except Exception as e:
                    print("embedding failed, skipping batch:", e)
                    continue  # a and b would be undefined below otherwise
                a, b, label = a.to(device), b.to(device), label.to(device)
                distance = mynet(a, b)
                distance = distance.squeeze(1)
                loss = criterion(distance, label.float())
                acc = self.binary_acc(distance, label.float()).item()
                avg_acc.append(acc)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if batch_id % 100 == 0:
                    print("epoch:", epoch, "batch:", batch_id, "train loss:", loss.item(), "acc:", acc)
                avg_loss.append(loss.item())
            avg_acc = np.array(avg_acc).mean()
            avg_loss = np.array(avg_loss).mean()
            print('train acc:', avg_acc)
            print("train loss", avg_loss)
    
        def eval(self, mynet, test_iter, criterion, epoch, device):
            mynet.eval()
            avg_acc = []
            avg_loss = []
            with torch.no_grad():
                for batch_id, (data1, data2, label) in enumerate(test_iter):
                    try:
                        a, b = self.word.sentenceTupleToEmbedding(data1, data2)
                    except Exception:
                        continue

                    a, b, label = a.to(device), b.to(device), label.to(device)
                    distance = mynet(a, b)
                    distance = distance.squeeze(1)
                    loss = criterion(distance, label.float())
                    acc = self.binary_acc(distance, label.float()).item()
                    avg_acc.append(acc)
                    avg_loss.append(loss.item())
                    # Evaluate on at most ~50 batches to keep validation fast.
                    if batch_id > 50:
                        break
            avg_acc = np.array(avg_acc).mean()
            avg_loss = np.array(avg_loss).mean()
            print('>>test acc:', avg_acc)
            print(">>test loss:", avg_loss)
            return (avg_acc, avg_loss)
    
        def run_train(self):
            model = SiameseLSTM(128).to(self.device)  # 128 = FastText embedding dimension
            max_acc = 0
            train_iter = self.get_loader(self.train_path)
            val_iter = self.get_loader(self.val_path)
            optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)

            for epoch in range(self.epoch):
                self.train(model, train_iter, optimizer, self.criterion, epoch, self.device)
                eval_acc, eval_loss = self.eval(model, val_iter, self.criterion, epoch, self.device)
                # Keep only the checkpoint with the best validation accuracy.
                if eval_acc > max_acc:
                    print("save model")
                    torch.save(model.state_dict(), self.check_point)
                    max_acc = eval_acc
    
    if __name__ == '__main__':
        RunLSTM().run_train()
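
    For completeness, a hedged inference sketch (this step is not in the original post): load the saved checkpoint and score one sentence pair. The relative checkpoint path here is a placeholder mirroring self.check_point above, and because training used BCEWithLogitsLoss, a sigmoid turns the raw logit into a probability:

    import torch
    from my_lstm import SiameseLSTM
    from my_word2vec import WordEmbedding

    model = SiameseLSTM(128)
    model.load_state_dict(torch.load("checkpoints/char_bilstm/char_bilstm.pth"))
    model.eval()

    word = WordEmbedding()
    a, b = word.sentenceTupleToEmbedding(("浙江杭州富阳区银湖街",), ("富阳区银湖街道新常村",))
    with torch.no_grad():
        prob = torch.sigmoid(model(a, b))
    print(prob.item())  # probability that the two sentences are similar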

    5. Results

    train acc: 0.779375
    train loss 0.5091823364257813
    >>test acc: 0.7703124992549419
    >>test loss: 0.5185132250189781
    epoch: 23 batch: 0 train loss: 0.6139101982116699 acc: 0.671875
    epoch: 23 batch: 100 train loss: 0.6397958397865295 acc: 0.703125
    epoch: 23 batch: 200 train loss: 0.6126863360404968 acc: 0.71875
    epoch: 23 batch: 300 train loss: 0.4715595543384552 acc: 0.8125
    epoch: 23 batch: 400 train loss: 0.5854585766792297 acc: 0.734375
    epoch: 23 batch: 500 train loss: 0.4749883711338043 acc: 0.78125
    epoch: 23 batch: 600 train loss: 0.4674433469772339 acc: 0.796875
    epoch: 23 batch: 700 train loss: 0.5099883079528809 acc: 0.765625
  • Original post: https://www.cnblogs.com/zhangxianrong/p/14773081.html