zoukankan      html  css  js  c++  java
  • 三、word2vec + siameseLSTM

    一、词嵌入

    import jieba
    from gensim.models import Word2Vec
    import torch
    import gensim
    import numpy as np
    # Pretrained word2vec vectors (binary format) shared by WordEmbedding below.
    # A forward slash is used as the separator. It is accepted on Windows and POSIX alike.
    # The original 'model\word2vec.bin' relied on the invalid escape '\w' (SyntaxWarning on
    # Python 3.12+) and named a file containing a literal backslash on Linux/macOS.
    model = gensim.models.KeyedVectors.load_word2vec_format('model/word2vec.bin', binary=True)
    class WordEmbedding(object):
        """Turn batches of raw sentences into zero-padded word2vec matrices.

        Relies on the module-level ``model`` (a gensim ``KeyedVectors`` loaded
        above) for word vectors and on ``jieba`` for word segmentation. Words
        missing from the vocabulary are skipped.
        """

        def __init__(self, embedding_dim=None):
            """Create an embedder.

            Args:
                embedding_dim: width of each word vector. Defaults to the loaded
                    model's ``vector_size``, so padding rows always match the
                    real vectors. The original hard-coded 128.
            """
            self.embedding_dim = model.vector_size if embedding_dim is None else embedding_dim

        def sentenceTupleToEmbedding(self, data1, data2):
            """Embed two aligned batches of sentences, padded to a common length.

            Args:
                data1: iterable of sentences (left side of each pair).
                data2: iterable of sentences (right side of each pair).

            Returns:
                Two ``torch.FloatTensor`` of shape (batch, seq_len, embedding_dim).
                ``seq_len`` is the largest jieba token count over both batches
                (out-of-vocabulary tokens included), and at least 1.
            """
            # Segment each sentence exactly once. The original segmented every
            # sentence twice: once to measure it, once to embed it.
            tokens_a = [list(jieba.cut(sentence)) for sentence in data1]
            tokens_b = [list(jieba.cut(sentence)) for sentence in data2]
            # Both sides share one padded length so the Siamese branches see
            # aligned shapes. The floor of 1 avoids an empty time axis for the LSTM.
            seq_len = max(1,
                          max(len(tokens) for tokens in tokens_a),
                          max(len(tokens) for tokens in tokens_b))
            a = self._tokens_to_batch(tokens_a, seq_len)  # batch_size, sequence, embedding
            b = self._tokens_to_batch(tokens_b, seq_len)
            return torch.FloatTensor(a), torch.FloatTensor(b)

        def sqence_vec(self, data, seq_len):
            """Embed raw sentences into an array of shape (len(data), seq_len, embedding_dim).

            In-vocabulary word vectors fill the leading rows. The remaining rows are zeros.
            """
            return self._tokens_to_batch([jieba.cut(sentence) for sentence in data], seq_len)

        def _tokens_to_batch(self, token_lists, seq_len):
            # Build one zero matrix per sentence and copy the known word vectors
            # into its leading rows. Starting from zeros (rather than vstack-ing
            # onto np.array(vectors)) keeps a sentence whose words are all OOV
            # valid: it simply stays all-zero instead of raising a shape error.
            # More in-vocab words than seq_len raises ValueError, as before.
            batch = []
            for tokens in token_lists:
                vectors = [model[word] for word in tokens if word in model]
                matrix = np.zeros((seq_len, self.embedding_dim), dtype=np.float32)
                if vectors:
                    matrix[:len(vectors)] = vectors
                batch.append(matrix)
            return np.array(batch)

    二、Dataset 设置

    import torch.utils.data as data
    import torch
    
    class DatasetIterater(data.Dataset):
        """Map-style dataset that zips two text columns with a label column.

        Index ``i`` yields ``(texta[i], textb[i], label[i])``. The dataset length
        is taken from ``texta``. The three columns are assumed to be aligned.
        """

        def __init__(self, texta, textb, label):
            self.texta, self.textb, self.label = texta, textb, label

        def __getitem__(self, item):
            columns = (self.texta, self.textb, self.label)
            return tuple(column[item] for column in columns)

        def __len__(self):
            return len(self.texta)

    三、SiameseLSTM

    import torch
    from torch import nn
    class SiameseLSTM(nn.Module):
        """Siamese LSTM that scores whether two embedded sequences match.

        Both inputs run through the same LSTM (shared weights). The last time
        step of each output is taken, and a linear layer maps their absolute
        difference to a single logit.
        """

        def __init__(self, input_size, hidden_size=10, num_layers=4):
            """Build the shared encoder and the scoring head.

            Args:
                input_size: width of each input time step (the embedding size).
                hidden_size: LSTM hidden width (default 10, as before).
                num_layers: number of stacked LSTM layers (default 4, as before).
            """
            super(SiameseLSTM, self).__init__()
            self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                                num_layers=num_layers, batch_first=True)
            self.fc = nn.Linear(hidden_size, 1)

        def forward(self, data1, data2):
            """Return one raw logit per pair, shape (batch, 1).

            ``data1``/``data2`` are (batch, seq, input_size) because the LSTM
            is ``batch_first``. Their sequence lengths may differ.
            """
            out1, _ = self.lstm(data1)
            out2, _ = self.lstm(data2)
            # NOTE(review): with right-padded inputs (zeros appended after the real
            # tokens) the last time step is read after the padding, not after the
            # last real word. Consider packing by length or left-padding. Confirm.
            pre1 = out1[:, -1, :]
            pre2 = out2[:, -1, :]
            dis = torch.abs(pre1 - pre2)
            # Unnormalised score; pair with BCEWithLogitsLoss (no sigmoid here).
            return self.fc(dis)

    四、mainProcess

    import torch
    from torch import nn
    from torch.utils.data import DataLoader
    import pandas as pd
    from datasetIterater import DatasetIterater
    import jieba
    from wordEmbedding import WordEmbedding
    from siameseLSTM import  SiameseLSTM
    
    learning_rate = 0.001
    # Parse the CSV once and take every column from the same frame.
    # The original called pd.read_csv three times on the same file.
    # NOTE(review): the file is named "negtive"; confirm its "tag" column holds
    # both 0 and 1 labels, since BCEWithLogitsLoss below treats it as a binary target.
    train_df = pd.read_csv("data/POI/negtive.csv")
    train_texta = train_df["address_1"]
    train_textb = train_df["address_2"]
    train_label = train_df["tag"]
    train_data = DatasetIterater(train_texta, train_textb, train_label)
    train_iter = DataLoader(dataset=train_data, batch_size=32, shuffle=True)

    net = SiameseLSTM(128)  # must match the word2vec vector size
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    word = WordEmbedding()
    train_loss = []
    for epoch in range(10):
        for batch_id, (data1, data2, label) in enumerate(train_iter):
            # Embed the raw address strings of this batch on the fly.
            a, b = word.sentenceTupleToEmbedding(data1, data2)
            logits = net(a, b)  # (batch, 1) raw scores, not probabilities
            loss = criterion(logits, label.float().unsqueeze(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
            if batch_id % 10 == 0:
                print(epoch, batch_id, loss.item())
  • 相关阅读:
    Mysql Got a packet bigger than 'max_allowed_packet' bytes
    Git之IDEA集成Git更新项目Update Type选项解释
    IDEA获取GIT仓库时更新类型update type的选择
    git merge和git rebase的区别
    git merge和git merge --no-ff的区别
    Git中fetch和pull命令的区别
    git官网下载太慢解决方法
    IDEA执行Thread.activeCount() = 2的问题
    k8s 常见错误汇总
    Axure9破解
  • 原文地址:https://www.cnblogs.com/zhangxianrong/p/14164118.html
Copyright © 2011-2022 走看看