一.目的
利用 PyTorch 实现基于前馈神经网络的意图分类,并在此基础上构建一个简单的聊天机器人(chatbot)。
二.数据
数据为英文语料(即程序读取的 intents.json),内容如下:
{'intents': [{'tag': 'greeting', 'patterns': ['Hi there', 'How are you', 'Is anyone there?', 'Hey', 'Hola', 'Hello', 'Good day'], 'responses': ['Hello, thanks for asking', 'Good to see you again', 'Hi there, how can I help?'], 'context': ['']}, {'tag': 'goodbye', 'patterns': ['Bye', 'See you later', 'Goodbye', 'Nice chatting to you, bye', 'Till next time'], 'responses': ['See you!', 'Have a nice day', 'Bye! Come back again soon.'], 'context': ['']}, {'tag': 'thanks', 'patterns': ['Thanks', 'Thank you', "That's helpful", 'Awesome, thanks', 'Thanks for helping me'], 'responses': ['Happy to help!', 'Any time!', 'My pleasure'], 'context': ['']}, {'tag': 'noanswer', 'patterns': [], 'responses': ["Sorry, can't understand you", 'Please give me more info', 'Not sure I understand'], 'context': ['']}, {'tag': 'options', 'patterns': ['How you could help me?', 'What you can do?', 'What help you provide?', 'How you can be helpful?', 'What support is offered'], 'responses': ['I can guide you through Adverse drug reaction list, Blood pressure tracking, Hospitals and Pharmacies', 'Offering support for Adverse drug reaction, Blood pressure, Hospitals and Pharmacies'], 'context': ['']}, {'tag': 'adverse_drug', 'patterns': ['How to check Adverse drug reaction?', 'Open adverse drugs module', 'Give me a list of drugs causing adverse behavior', 'List all drugs suitable for patient with adverse reaction', 'Which drugs dont have adverse reaction?'], 'responses': ['Navigating to Adverse drug reaction module'], 'context': ['']}, {'tag': 'blood_pressure', 'patterns': ['Open blood pressure module', 'Task related to blood pressure', 'Blood pressure data entry', 'I want to log blood pressure results', 'Blood pressure data management'], 'responses': ['Navigating to Blood Pressure module'], 'context': ['']}, {'tag': 'blood_pressure_search', 'patterns': ['I want to search for blood pressure result history', 'Blood pressure for patient', 'Load patient blood pressure result', 'Show blood pressure results for patient', 'Find blood pressure results by ID'], 'responses': ['Please provide Patient ID', 'Patient ID?'], 'context': ['search_blood_pressure_by_patient_id']}, {'tag': 'search_blood_pressure_by_patient_id', 'patterns': [], 'responses': ['Loading Blood pressure result for Patient'], 'context': ['']}, {'tag': 'pharmacy_search', 'patterns': ['Find me a pharmacy', 'Find pharmacy', 'List of pharmacies nearby', 'Locate pharmacy', 'Search pharmacy'], 'responses': ['Please provide pharmacy name'], 'context': ['search_pharmacy_by_name']}, {'tag': 'search_pharmacy_by_name', 'patterns': [], 'responses': ['Loading pharmacy details'], 'context': ['']}, {'tag': 'hospital_search', 'patterns': ['Lookup for hospital', 'Searching for hospital to transfer patient', 'I want to search hospital data', 'Hospital lookup for patient', 'Looking up hospital details'], 'responses': ['Please provide hospital name or location'], 'context': ['search_hospital_by_params']}, {'tag': 'search_hospital_by_params', 'patterns': [], 'responses': ['Please provide hospital type'], 'context': ['search_hospital_by_type']}, {'tag': 'search_hospital_by_type', 'patterns': [], 'responses': ['Loading hospital details'], 'context': ['']}]}
其中,tag 表示意图类别,patterns 为该类别下的训练样本,responses 为识别出该意图后的候选回复语句(程序从中随机选取一条作为回复)。
三.程序
完整的程序和数据见:https://github.com/jiangnanboy/chatbot
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import numpy as np
import os
import json
import random


def _load_pickle(path):
    """Unpickle and return the object stored at *path*.

    pickle can execute arbitrary code while loading, so only point this at
    trusted files.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# All artifacts are resolved relative to the current working directory,
# so the script must be started from the directory that contains them.
intent_json_path = os.path.join(os.getcwd(), "intents.json")
with open(intent_json_path, 'r', encoding='utf-8') as f:
    intents = json.load(f)

words_path = os.path.join(os.getcwd(), "words.pkl")
words = _load_pickle(words_path)  # vocabulary used to build bag-of-words vectors

classes_path = os.path.join(os.getcwd(), "classes.pkl")
classes = _load_pickle(classes_path)  # intent classes; its length sets the output size

classes_index_path = os.path.join(os.getcwd(), "classes_index.pkl")
classes_index = _load_pickle(classes_index_path)  # intent tag -> class index

# Invert classes_index so a predicted integer index maps back to its tag.
index_classes = {index: tag for tag, index in classes_index.items()}
print('index_classes:{}'.format(index_classes))


class classifyModel(nn.Module):
    """Feed-forward intent classifier over bag-of-words input.

    Layout: input -> 128 -> 64 -> num_classes, with ReLU and dropout after
    each hidden layer. forward() returns raw logits; softmax is applied by
    the caller.

    Args:
        input_size: length of the input vector; defaults to ``len(words)``.
        num_classes: number of output classes; defaults to ``len(classes)``.
        dropout: dropout probability used after each hidden layer.
    """

    def __init__(self, input_size=None, num_classes=None, dropout=0.5):
        super().__init__()
        if input_size is None:
            input_size = len(words)
        if num_classes is None:
            num_classes = len(classes)
        # The attribute must stay named `model`: saved state_dict keys are
        # prefixed with it ("model.0.weight", ...).
        self.model = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Return class logits for the batch of bag-of-words vectors *x*."""
        return self.model(x)


model = classifyModel()
# Despite the .h5 extension this file is read with torch.load as a
# state_dict. torch.load is pickle-based, so only load trusted files.
# map_location="cpu" lets weights saved from a GPU load on a CPU-only machine.
model_path = os.path.join(os.getcwd(), "chatbot_model.h5")
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()  # inference only: disable dropout
import nltk
from nltk.stem import WordNetLemmatizer

# Created once at module level and reused by clean_up_sentence().
lemmatizer = WordNetLemmatizer()


def clean_up_sentence(sentence):
    """Tokenize *sentence* and return its lower-cased, lemmatized tokens.

    NOTE(review): nltk.word_tokenize and WordNetLemmatizer rely on NLTK data
    packages (tokenizer models, WordNet) being installed — see nltk.download.
    """
    sentence_words = nltk.word_tokenize(sentence)  # tokenize
    # Lower-case before lemmatizing so tokens can match vocabulary entries.
    return [lemmatizer.lemmatize(word.lower()) for word in sentence_words]


def bow(sentence, words, show_detail=True):
    """Encode *sentence* as a binary bag-of-words vector over *words*.

    Args:
        sentence: raw input text.
        words: vocabulary; position i of the bag corresponds to words[i].
        show_detail: if True, print every vocabulary word found.

    Returns:
        ``[bag]`` — a batch containing one vector of 0/1 ints of length
        ``len(words)``, ready to be turned into a model input tensor.
    """
    # Set lookup turns the former O(len(sentence) * len(words)) scan into O(n + m).
    sentence_words = set(clean_up_sentence(sentence))
    bag = [0] * len(words)
    for i, w in enumerate(words):
        if w in sentence_words:
            bag[i] = 1  # vocabulary word occurs in the sentence
            if show_detail:
                print("found in bag:{}".format(w))
    return [bag]


def predict_class(sentence, model, verbose=True):
    """Predict the most likely intent for *sentence*.

    Args:
        sentence: raw user text.
        model: classifier returning logits of shape (1, num_classes) for a
            (1, len(words)) float input.
        verbose: print intermediate tensors and the result (default keeps
            the original debug output).

    Returns:
        ``[{'intent': tag, 'prob': str(probability)}]`` for the top class.
    """
    sentence_bag = bow(sentence, words, False)
    model.eval()  # make sure dropout is off
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(torch.tensor(sentence_bag, dtype=torch.float32))
        # Probability and index of the most likely class.
        predicted_prob, predicted_index = torch.max(F.softmax(outputs, dim=1), dim=1)
    if verbose:
        print('outputs:{}'.format(outputs))
        print('softmax_prob:{}'.format(predicted_prob))
        print('softmax_index:{}'.format(predicted_index))
    results = [{
        'intent': index_classes[int(predicted_index[0])],
        'prob': str(predicted_prob.numpy()[0]),
    }]
    if verbose:
        print('result:{}'.format(results))
    return results


def get_response(predict_result, intents_json):
    """Return a random canned response for the predicted intent.

    Args:
        predict_result: output of predict_class(); its first entry's
            'intent' is the tag to look up.
        intents_json: parsed intents data with an 'intents' list of
            {'tag', 'responses', ...} dicts.

    Raises:
        KeyError: if no intent in *intents_json* carries the predicted tag.
    """
    tag = predict_result[0]['intent']
    for intent in intents_json['intents']:
        if intent['tag'] == tag:
            return random.choice(intent['responses'])
    raise KeyError('no intent with tag {!r} in intents_json'.format(tag))


def chatbot_response(text):
    """Classify *text* and return the bot's reply string."""
    predict_result = predict_class(text, model)
    return get_response(predict_result, intents)


print(chatbot_response("Lookup for hospital"))
import tkinter
from tkinter import DISABLED, END, FALSE, NORMAL, Button, Scrollbar, Text, Tk


def send(event=None):
    """Send the typed message to the bot and append both turns to the log.

    Args:
        event: Tk event when triggered by the <Return> key binding; unused.
            Accepting it lets one handler serve both the button and the key.

    Returns:
        "break", so a bound <Return> does not also insert a newline into
        the entry box (ignored when called as the button command).
    """
    msg = EntryBox.get("1.0", 'end-1c').strip()
    EntryBox.delete("1.0", END)  # Text indices start at line 1, column 0
    if msg == '':
        return "break"
    ChatLog.config(state=NORMAL)  # the log is read-only except while writing
    try:
        ChatLog.insert(END, "你: " + msg + '\n\n')
        ChatLog.config(foreground="#442265", font=("Verdana", 12))
        res = chatbot_response(msg)
        ChatLog.insert(END, "机器人: " + res + '\n\n')
    finally:
        # Restore read-only state even if chatbot_response raises.
        ChatLog.config(state=DISABLED)
    ChatLog.yview(END)  # scroll to the newest message
    return "break"


base = Tk()
base.title("Hello")
base.geometry("400x500")
base.resizable(width=FALSE, height=FALSE)

# Chat history window (read-only; send() enables it while appending).
ChatLog = Text(base, bd=0, bg="white", height="8", width="50", font="Arial")
ChatLog.config(state=DISABLED)

# Scrollbar bound to the chat window.
scrollbar = Scrollbar(base, command=ChatLog.yview, cursor="heart")
ChatLog['yscrollcommand'] = scrollbar.set

# Send button.
SendButton = Button(base, font=("Verdana", 12, 'bold'), text="发送", width="12", height=5,
                    bd=0, bg="#32de97", activebackground="#3c9d9b", fg='#ffffff',
                    command=send)

# Message input box; pressing Enter sends the message.
EntryBox = Text(base, bd=0, bg="white", width="29", height="5", font="Arial")
EntryBox.bind("<Return>", send)

# Absolute placement of all widgets.
scrollbar.place(x=376, y=6, height=386)
ChatLog.place(x=6, y=6, height=386, width=370)
EntryBox.place(x=128, y=401, height=90, width=265)
SendButton.place(x=6, y=401, height=90)

base.mainloop()
四.结果