zoukankan      html  css  js  c++  java
  • 利用pytorch实现前馈网络分类的chatbot

    一.目的

    利用 PyTorch 实现基于前馈网络的意图分类,并在此基础上构建一个简单的 chatbot。

    二.数据

    数据为英文数据,如下:

    {'intents': [{'tag': 'greeting', 'patterns': ['Hi there', 'How are you', 'Is anyone there?', 'Hey', 'Hola', 'Hello', 'Good day'], 'responses': ['Hello, thanks for asking', 'Good to see you again', 'Hi there, how can I help?'], 'context': ['']}, {'tag': 'goodbye', 'patterns': ['Bye', 'See you later', 'Goodbye', 'Nice chatting to you, bye', 'Till next time'], 'responses': ['See you!', 'Have a nice day', 'Bye! Come back again soon.'], 'context': ['']}, {'tag': 'thanks', 'patterns': ['Thanks', 'Thank you', "That's helpful", 'Awesome, thanks', 'Thanks for helping me'], 'responses': ['Happy to help!', 'Any time!', 'My pleasure'], 'context': ['']}, {'tag': 'noanswer', 'patterns': [], 'responses': ["Sorry, can't understand you", 'Please give me more info', 'Not sure I understand'], 'context': ['']}, {'tag': 'options', 'patterns': ['How you could help me?', 'What you can do?', 'What help you provide?', 'How you can be helpful?', 'What support is offered'], 'responses': ['I can guide you through Adverse drug reaction list, Blood pressure tracking, Hospitals and Pharmacies', 'Offering support for Adverse drug reaction, Blood pressure, Hospitals and Pharmacies'], 'context': ['']}, {'tag': 'adverse_drug', 'patterns': ['How to check Adverse drug reaction?', 'Open adverse drugs module', 'Give me a list of drugs causing adverse behavior', 'List all drugs suitable for patient with adverse reaction', 'Which drugs dont have adverse reaction?'], 'responses': ['Navigating to Adverse drug reaction module'], 'context': ['']}, {'tag': 'blood_pressure', 'patterns': ['Open blood pressure module', 'Task related to blood pressure', 'Blood pressure data entry', 'I want to log blood pressure results', 'Blood pressure data management'], 'responses': ['Navigating to Blood Pressure module'], 'context': ['']}, {'tag': 'blood_pressure_search', 'patterns': ['I want to search for blood pressure result history', 'Blood pressure for patient', 'Load patient blood pressure result', 'Show blood 
pressure results for patient', 'Find blood pressure results by ID'], 'responses': ['Please provide Patient ID', 'Patient ID?'], 'context': ['search_blood_pressure_by_patient_id']}, {'tag': 'search_blood_pressure_by_patient_id', 'patterns': [], 'responses': ['Loading Blood pressure result for Patient'], 'context': ['']}, {'tag': 'pharmacy_search', 'patterns': ['Find me a pharmacy', 'Find pharmacy', 'List of pharmacies nearby', 'Locate pharmacy', 'Search pharmacy'], 'responses': ['Please provide pharmacy name'], 'context': ['search_pharmacy_by_name']}, {'tag': 'search_pharmacy_by_name', 'patterns': [], 'responses': ['Loading pharmacy details'], 'context': ['']}, {'tag': 'hospital_search', 'patterns': ['Lookup for hospital', 'Searching for hospital to transfer patient', 'I want to search hospital data', 'Hospital lookup for patient', 'Looking up hospital details'], 'responses': ['Please provide hospital name or location'], 'context': ['search_hospital_by_params']}, {'tag': 'search_hospital_by_params', 'patterns': [], 'responses': ['Please provide hospital type'], 'context': ['search_hospital_by_type']}, {'tag': 'search_hospital_by_type', 'patterns': [], 'responses': ['Loading hospital details'], 'context': ['']}]}

    其中,tag表示意图类别,patterns表示类别下的样本,responses为意图识别后的回应语句。

    三.程序

    完整程序和数据见(https://github.com/jiangnanboy/chatbot)

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import pickle
    import numpy as np
    import os
    import json
    import random
    
    # Intent definitions: read from intents.json in the current working directory.
    intent_json_path = os.path.join(os.getcwd(), "intents.json")
    with open(intent_json_path, mode='r', encoding='utf-8') as f:
        intents = json.loads(f.read())

    # NOTE: the three pickle files below are deserialized as-is; only load
    # files that come from a trusted source.

    # Vocabulary used to size the model input and to build bag-of-words vectors.
    words_path = os.path.join(os.getcwd(), "words.pkl")
    with open(words_path, mode='rb') as f_words:
        words = pickle.loads(f_words.read())

    # Intent classes used to size the model output.
    classes_path = os.path.join(os.getcwd(), "classes.pkl")
    with open(classes_path, mode='rb') as f_classes:
        classes = pickle.loads(f_classes.read())

    # Mapping whose values are inverted below to look a class up by index.
    classes_index_path = os.path.join(os.getcwd(), "classes_index.pkl")
    with open(classes_index_path, mode='rb') as f_classes_index:
        classes_index = pickle.loads(f_classes_index.read())

    # Invert key -> value into value -> key.
    index_classes = {value: key for key, value in classes_index.items()}
    print(f'index_classes:{index_classes}')
    class classifyModel(nn.Module):
        """Feed-forward classifier: input -> 128 -> 64 -> number of classes.

        Each hidden layer is followed by ReLU and Dropout(0.5). The layers are
        kept in a single ``nn.Sequential`` stored as ``self.model`` so that the
        parameter names in the state dict (``model.0.*``, ``model.3.*``,
        ``model.6.*``) stay unchanged.
        """

        def __init__(self, input_size=None, num_classes=None):
            """Build the network.

            Args:
                input_size: Number of input features. Defaults to
                    ``len(words)`` (the module-level vocabulary).
                num_classes: Number of output logits. Defaults to
                    ``len(classes)`` (the module-level class list).
            """
            super().__init__()
            # Fall back to the module-level globals so that the original
            # zero-argument call ``classifyModel()`` keeps working.
            if input_size is None:
                input_size = len(words)
            if num_classes is None:
                num_classes = len(classes)
            self.model = nn.Sequential(
                nn.Linear(input_size, 128),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(64, num_classes),
            )

        def forward(self, x):
            """Return the raw (un-normalized) class scores for ``x``."""
            return self.model(x)
            
    # Instantiate the classifier and restore its trained weights from the
    # current working directory.
    model = classifyModel()
    model_path = os.path.join(os.getcwd(), "chatbot_model.h5")
    # map_location forces every tensor onto the CPU, so a checkpoint saved on
    # a GPU machine can still be loaded where no GPU is available.
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    # This script only runs inference: switch dropout off once after loading.
    model.eval()
    import nltk
    from nltk.stem import WordNetLemmatizer
    
    # Single shared lemmatizer instance, used by clean_up_sentence().
    lemmatizer = WordNetLemmatizer()
    
    def clean_up_sentence(sentence):
        """Tokenize ``sentence`` and return its lower-cased, lemmatized tokens."""
        normalized = []
        # Tokenize first, then lower-case and lemmatize every token in order.
        for token in nltk.word_tokenize(sentence):
            normalized.append(lemmatizer.lemmatize(token.lower()))
        return normalized
    
    def bow(sentence, words, show_detail = True):
        """Encode ``sentence`` as a binary bag-of-words vector over ``words``.

        Args:
            sentence: Raw input text; it is normalized with
                ``clean_up_sentence`` before matching.
            words: Vocabulary sequence; position ``i`` of the result
                corresponds to ``words[i]``.
            show_detail: When true, print each vocabulary word that was found
                in the sentence.

        Returns:
            A list holding one bag: ``[[0 or 1, ...]]`` with ``len(words)``
            entries, where 1 marks a vocabulary word present in the sentence.
        """
        # A set gives O(1) membership tests instead of rescanning the whole
        # vocabulary once per sentence token.
        sentence_words = set(clean_up_sentence(sentence))
        bag = []
        for w in words:
            found = w in sentence_words
            bag.append(1 if found else 0)
            # Report only real matches; previously this message was printed
            # for every vocabulary word regardless of whether it matched.
            if found and show_detail:
                print("found in bag:{}".format(w))
        return [bag]
    
    def predict_class(sentence, model):
        """Classify ``sentence`` with ``model`` and return the top intent.

        Returns a one-element list ``[{'intent': <class>, 'prob': <str>}]``,
        where ``prob`` is the softmax probability of the predicted class
        converted to a string. Intermediate values are printed for debugging.
        """
        features = torch.FloatTensor(bow(sentence, words, False))
        model.eval()
        outputs = model(features)
        print(f'outputs:{outputs}')

        # Probability and index of the most likely class.
        probabilities = F.softmax(outputs, 1)
        predicted_prob, predicted_index = torch.max(probabilities, 1)
        print(f'softmax_prob:{predicted_prob}')
        print(f'softmax_index:{predicted_index}')

        # The batch holds a single sentence, so take element 0 of each result.
        top_index = predicted_index.detach().numpy()[0]
        top_prob = predicted_prob.detach().numpy()[0]
        results = [{'intent': index_classes[top_index], 'prob': str(top_prob)}]
        print(f'result:{results}')
        return results
     
    def get_response(predict_result, intents_json):
        """Pick a random canned response for the predicted intent.

        Args:
            predict_result: List whose first element is a dict with an
                ``'intent'`` key (the predicted tag).
            intents_json: Dict with an ``'intents'`` list; each entry carries
                a ``'tag'`` and a list of ``'responses'``.

        Returns:
            One response string. If the predicted tag is unknown or has no
            responses, a response of the ``'noanswer'`` intent is used; if
            that is unavailable too, a fixed apology string is returned.
            Previously those cases raised UnboundLocalError / IndexError.
        """
        # First occurrence of a tag wins, matching the original first-match scan.
        responses_by_tag = {}
        for intent in intents_json['intents']:
            responses_by_tag.setdefault(intent.get('tag'), intent.get('responses') or [])

        tag = predict_result[0]['intent'] if predict_result else None
        for candidate in (tag, 'noanswer'):
            responses = responses_by_tag.get(candidate)
            if responses:
                return random.choice(responses)
        # Last resort when neither the tag nor 'noanswer' yields a response.
        return "Sorry, can't understand you"
    
    def chatbot_response(text):
        """Return the bot's reply to ``text``: classify it, then pick a response."""
        return get_response(predict_class(text, model), intents)
    # Smoke test: answer one sample query and print the reply before the GUI starts.
    print(chatbot_response("Lookup for hospital"))
    import tkinter
    from tkinter import *
    
    def send():
        """Handle the send button: post the typed message and the bot's reply.

        Reads and clears the entry box. If the stripped text is non-empty, the
        user's message and the chatbot's response are appended to the chat
        log, which is then scrolled to the end.
        """
        msg = EntryBox.get("1.0", 'end-1c').strip()
        EntryBox.delete("1.0", END)
        if msg == '':
            return

        # The log is normally read-only; unlock it only while inserting text.
        ChatLog.config(state=NORMAL)
        try:
            ChatLog.insert(END, "你: " + msg + '\n\n')
            ChatLog.config(foreground="#442265", font=("Verdana", 12))
            res = chatbot_response(msg)
            ChatLog.insert(END, "机器人: " + res + '\n\n')
        finally:
            # Restore the read-only state even if producing the reply fails.
            ChatLog.config(state=DISABLED)
            ChatLog.yview(END)
    # Build the main window: fixed 400x500 size, resizing disabled.
    base = Tk()
    base.title("Hello")
    base.geometry("400x500")
    base.resizable(width=FALSE, height=FALSE)
    # Chat history widget; kept DISABLED (read-only) except while send() writes to it.
    ChatLog = Text(base, bd=0, bg="white", height="8", width="50", font="Arial",)
    ChatLog.config(state=DISABLED)
    # Link the scrollbar and the chat history in both directions.
    scrollbar = Scrollbar(base, command=ChatLog.yview, cursor="heart")
    ChatLog['yscrollcommand'] = scrollbar.set
    # Button that calls send() when clicked.
    SendButton = Button(base, font=("Verdana",12,'bold'), text="发送", width="12", height=5,
                        bd=0, bg="#32de97", activebackground="#3c9d9b",fg='#ffffff',
                        command= send )
    # Multi-line text box where the user types a message.
    EntryBox = Text(base, bd=0, bg="white",width="29", height="5", font="Arial")
    # NOTE(review): the <Return> binding below is disabled. send() takes no
    # arguments, so it would presumably need an event parameter before this
    # could be enabled - confirm against the tkinter bind() documentation.
    #EntryBox.bind("<Return>", send)
    # Place all widgets at absolute pixel positions.
    scrollbar.place(x=376,y=6, height=386)
    ChatLog.place(x=6,y=6, height=386, width=370)
    EntryBox.place(x=128, y=401, height=90, width=265)
    SendButton.place(x=6, y=401, height=90)
    # Start the Tk event loop.
    base.mainloop()

    四.结果

  • 相关阅读:
    启明星门户网站Portal发布V4.5,并兼论部分功能的实现
    修改SQL数据库dbo所有者
    iphone& android 开发指南 http://mobile.tutsplus.com
    启明星会议室预定系统V5.0.0.0版本说明
    启明星Portal企业内部网站V4.3版 附演示地址 http://demo.dotnetcms.org
    在winform程序里实现最小化隐藏到windows右下角
    【门户网站】启明星Portal系统里,关于天气预报调用的说明
    获取客户端经纬度坐标
    修改表名或者列名SQL
    ER图
  • 原文地址:https://www.cnblogs.com/little-horse/p/14033132.html
Copyright © 2011-2022 走看看