1. Defining the data loading
my_dataset.py
import torch.utils.data as data


class MyDataset(data.Dataset):
    """Pairs two text columns with a label for similarity training."""

    def __init__(self, texta, textb, label):
        self.texta = texta
        self.textb = textb
        self.label = label

    def __getitem__(self, item):
        texta = self.texta[item]
        textb = self.textb[item]
        label = self.label[item]
        return texta, textb, label

    def __len__(self):
        return len(self.texta)
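A minimal usage sketch (the file path here is illustrative, not from the original project; the column names s1, s2, y match the training script below): wrapping MyDataset in a DataLoader yields batches of raw strings plus a label tensor, which is exactly what the training loop expects.

import pandas as pd
import torch
from torch.utils.data import DataLoader

# Hypothetical space-separated file with columns s1, s2, y.
df = pd.read_csv("train.csv", sep=" ")
dataset = MyDataset(df["s1"], df["s2"], torch.LongTensor(list(df["y"])))
loader = DataLoader(dataset, batch_size=4, shuffle=True)
for texta, textb, label in loader:
    print(texta, textb, label)  # texta/textb are lists of strings, label is a tensor
    break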
2. Defining the word embedding
my_word2vec.py
from gensim.models.fasttext import FastText
import torch
import numpy as np
import os


class WordEmbedding(object):
    def __init__(self):
        parent_path = os.path.split(os.path.realpath(__file__))[0]
        self.root = parent_path[:parent_path.find("models")]  # E:personassemantics
        self.word_fasttext = os.path.join(self.root, "checkpoints", "word2vec", "word_fasttext.model")
        self.char_fasttext = os.path.join(self.root, "checkpoints", "word2vec", "char_fasttext.model")
        # Character-level FastText model; each character maps to a 128-dim vector.
        self.model = FastText.load(self.char_fasttext)

    def sentenceTupleToEmbedding(self, data1, data2):
        # Pad both batches to the length (in characters) of the longest sentence.
        aCutListMaxLen = max([len(list(str(sentence_a))) for sentence_a in data1])
        bCutListMaxLen = max([len(list(str(sentence_b))) for sentence_b in data2])
        seq_len = max(aCutListMaxLen, bCutListMaxLen)
        a = self.sequence_vec(data1, seq_len)  # (batch_size, seq_len, embedding_dim)
        b = self.sequence_vec(data2, seq_len)
        return torch.FloatTensor(a), torch.FloatTensor(b)

    def sequence_vec(self, data, seq_len):
        data_vec = []
        for sequence in data:
            char_vecs = []  # seq_len * 128
            for char in list(str(sequence)):
                if char in self.model.wv:
                    char_vecs.append(self.model.wv[char])
            # reshape(-1, 128) keeps the shape valid even if no character is in the vocabulary
            char_vecs = np.array(char_vecs).reshape(-1, 128)
            padding = np.zeros((seq_len - char_vecs.shape[0], 128))
            data_vec.append(np.vstack((char_vecs, padding)))
        return np.array(data_vec)


if __name__ == '__main__':
    word = WordEmbedding()
    data1 = ("浙江杭州富阳区银湖街黄先生的外卖", "浙江杭州富阳区银湖街黄先生的外卖")
    data2 = ("富阳区浙江富阳区银湖街道新常村", "浙江杭州富阳区银湖街黄先生的外卖")
    a, b = word.sentenceTupleToEmbedding(data1, data2)
    print(a.shape)
    print(b)
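The class loads a pre-trained character-level FastText model from checkpoints/word2vec/char_fasttext.model, but the post does not show how it was trained. A rough sketch of how such a 128-dimensional character model could be produced with gensim 4.x (the corpus path, window, and epoch count are assumptions, not from the original; gensim 3.x uses size= instead of vector_size=):

from gensim.models.fasttext import FastText

# Hypothetical corpus: one address per line; each sentence is its list of characters.
with open("addresses.txt", encoding="utf-8") as f:
    sentences = [list(line.strip()) for line in f if line.strip()]

# vector_size=128 matches the embedding dimension assumed by WordEmbedding above.
model = FastText(sentences, vector_size=128, window=3, min_count=1, epochs=10)
model.save("char_fasttext.model")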
3. Defining the model
my_lstm.py
import torch
from torch import nn


class SiameseLSTM(nn.Module):
    def __init__(self, input_size):
        super(SiameseLSTM, self).__init__()
        # Both branches share this LSTM, i.e. the same weights encode both inputs.
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=10, num_layers=1, batch_first=True)
        self.fc = nn.Sequential(
            nn.Linear(20, 1),
        )

    def forward(self, data1, data2):
        out1, (h1, c1) = self.lstm(data1)
        out2, (h2, c2) = self.lstm(data2)
        # Take the last time step of each branch and concatenate: (batch, 10) + (batch, 10) -> (batch, 20)
        pre1 = out1[:, -1, :]
        pre2 = out2[:, -1, :]
        pre = torch.cat([pre1, pre2], dim=1)
        out = self.fc(pre)  # raw logit; BCEWithLogitsLoss applies the sigmoid later
        return out
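A quick shape check with random tensors (batch of 4, sequence length 16, 128-dim character embeddings, matching the dimensions used elsewhere in this post):

import torch

model = SiameseLSTM(input_size=128)
a = torch.randn(4, 16, 128)  # (batch, seq_len, embedding_dim)
b = torch.randn(4, 16, 128)
out = model(a, b)
print(out.shape)  # torch.Size([4, 1]) -- one similarity logit per sentence pair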
4. Defining the training run
run__lstm.py
import torch
import os
from torch.utils.data import DataLoader
from my_dataset import MyDataset
import pandas as pd
import numpy as np
from my_lstm import SiameseLSTM
import torch.nn as nn
from my_word2vec import WordEmbedding


class RunLSTM:
    def __init__(self):
        self.learning_rate = 0.001
        self.device = torch.device("cpu")
        parent_path = os.path.split(os.path.realpath(__file__))[0]
        self.root = parent_path[:parent_path.find("models")]  # E:personassemantics
        self.train_path = os.path.join(self.root, "datas", "bert_data", "sim_data", "train.csv")
        self.val_path = os.path.join(self.root, "datas", "bert_data", "sim_data", "val.csv")
        self.test_path = os.path.join(self.root, "datas", "bert_data", "sim_data", "test.csv")
        self.batch_size = 64
        self.epoch = 50
        self.criterion = nn.BCEWithLogitsLoss().to(self.device)
        self.word = WordEmbedding()
        self.check_point = os.path.join(self.root, "checkpoints", "char_bilstm", "char_bilstm.pth")

    def get_loader(self, path):
        # Space-separated file with columns s1, s2, y.
        data = pd.read_csv(path, sep=" ")
        d1, d2, y = data["s1"], data["s2"], list(data["y"])
        dataset = MyDataset(d1, d2, torch.LongTensor(y))
        data_iter = DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
        return data_iter

    def binary_acc(self, preds, y):
        preds = torch.round(torch.sigmoid(preds))
        correct = torch.eq(preds, y).float()
        acc = correct.sum() / len(correct)
        return acc

    def train(self, mynet, train_iter, optimizer, criterion, epoch, device):
        avg_acc = []
        avg_loss = []
        mynet.train()
        for batch_id, (data1, data2, label) in enumerate(train_iter):
            try:
                a, b = self.word.sentenceTupleToEmbedding(data1, data2)
            except Exception as e:
                print("embedding error, skipping batch:", e)
                continue  # without this, a and b could be undefined below
            a, b, label = a.to(device), b.to(device), label.to(device)
            logits = mynet(a, b)
            logits = logits.squeeze(1)
            loss = criterion(logits, label.float())
            acc = self.binary_acc(logits, label.float()).item()
            avg_acc.append(acc)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if batch_id % 100 == 0:
                print("epoch:", epoch, "batch:", batch_id, "train loss:", loss.item(), "acc:", acc)
            avg_loss.append(loss.item())
        avg_acc = np.array(avg_acc).mean()
        avg_loss = np.array(avg_loss).mean()
        print('train acc:', avg_acc)
        print("train loss", avg_loss)

    def eval(self, mynet, test_iter, criterion, epoch, device):
        mynet.eval()
        avg_acc = []
        avg_loss = []
        with torch.no_grad():
            for batch_id, (data1, data2, label) in enumerate(test_iter):
                try:
                    a, b = self.word.sentenceTupleToEmbedding(data1, data2)
                except Exception as e:
                    continue
                a, b, label = a.to(device), b.to(device), label.to(device)
                logits = mynet(a, b)
                logits = logits.squeeze(1)
                loss = criterion(logits, label.float())
                acc = self.binary_acc(logits, label.float()).item()
                avg_acc.append(acc)
                avg_loss.append(loss.item())
                if batch_id > 50:  # only evaluate on the first batches to save time
                    break
        avg_acc = np.array(avg_acc).mean()
        avg_loss = np.array(avg_loss).mean()
        print('>>test acc:', avg_acc)
        print(">>test loss:", avg_loss)
        return (avg_acc, avg_loss)

    def run_train(self):
        model = SiameseLSTM(128).to(self.device)
        max_acc = 0
        train_iter = self.get_loader(self.train_path)
        val_iter = self.get_loader(self.val_path)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
        for epoch in range(self.epoch):
            self.train(model, train_iter, optimizer, self.criterion, epoch, self.device)
            eval_acc, eval_loss = self.eval(model, val_iter, self.criterion, epoch, self.device)
            # Keep the checkpoint with the best validation accuracy.
            if eval_acc > max_acc:
                print("save model")
                torch.save(model.state_dict(), self.check_point)
                max_acc = eval_acc


if __name__ == '__main__':
    RunLSTM().run_train()
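After training, the saved weights can be loaded for inference. A rough sketch of scoring one sentence pair (this helper is illustrative and not part of the original post; the checkpoint path is assumed to point at the file saved by run_train above):

import torch
from my_lstm import SiameseLSTM
from my_word2vec import WordEmbedding

def predict_similarity(sent_a, sent_b, checkpoint="char_bilstm.pth"):
    # Hypothetical inference helper reusing the components defined above.
    word = WordEmbedding()
    model = SiameseLSTM(128)
    model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
    model.eval()
    with torch.no_grad():
        a, b = word.sentenceTupleToEmbedding((sent_a,), (sent_b,))
        logit = model(a, b).squeeze(1)
        return torch.sigmoid(logit).item()  # probability that the pair is similar

print(predict_similarity("浙江杭州富阳区银湖街黄先生的外卖", "富阳区浙江富阳区银湖街道新常村"))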
5. Run results
train acc: 0.779375
train loss 0.5091823364257813
>>test acc: 0.7703124992549419
>>test loss: 0.5185132250189781
epoch: 23 batch: 0 train loss: 0.6139101982116699 acc: 0.671875
epoch: 23 batch: 100 train loss: 0.6397958397865295 acc: 0.703125
epoch: 23 batch: 200 train loss: 0.6126863360404968 acc: 0.71875
epoch: 23 batch: 300 train loss: 0.4715595543384552 acc: 0.8125
epoch: 23 batch: 400 train loss: 0.5854585766792297 acc: 0.734375
epoch: 23 batch: 500 train loss: 0.4749883711338043 acc: 0.78125
epoch: 23 batch: 600 train loss: 0.4674433469772339 acc: 0.796875
epoch: 23 batch: 700 train loss: 0.5099883079528809 acc: 0.765625