zoukankan      html  css  js  c++  java
  • PyQt训练BP模型时,显示waiting动图(多线程)

    1、实现效果

    2、相关代码

    实现BP训练模型的线程类

     1 class WorkThread(QtCore.QThread):
     2     finish_trigger = QtCore.pyqtSignal()  # 关闭waiting_gif
     3     result_trigger = QtCore.pyqtSignal(pd.Series)  # 传递预测结果信号
     4     evaluate_trigger = QtCore.pyqtSignal(list)  # 传递正确率信号
     5 
     6     def __int__(self):
     7         super(WorkThread, self).__init__()
     8 
     9     def init(self, dataset, feature, label, info):
    10         self.dataset = dataset
    11         self.feature = feature
    12         self.label = label
    13         self.info = info
    14 
    15     # 可以认为,run()函数就是新的线程需要执行的代码
    16     def run(self):
    17         self.BP()
    18 
    19     def BP(self):
    20         """
    21         BP神经网络,返回标签的预测数据
    22         :param parent:
    23         :param dataset:
    24         :param feature:
    25         :param label:
    26         :param info:
    27         :return:
    28         """
    29         dataset = self.dataset
    30         feature = self.feature
    31         label = self.label
    32         info = self.info
    33 
    34         input_dim = len(feature)
    35         data_x = dataset[feature]  # 特征数据
    36         data_y = dataset[label]  # 标签数据
    37 
    38         x_train, x_test, y_train, y_test = train_test_split(data_x, data_y, test_size=info[0][3])
    39     
    40         # **********************建立一个简单BP神经网络模型*********************************
    41         self.model = Sequential()  # 声明一个顺序模型
    42         count = len(info)
    43         for i in range(1, count-1):
    44             if i == 1:
    45                 self.model.add(Dense(info[i][0], activation=info[i][1], input_dim=input_dim, kernel_initializer=info[i][2]))  # 输入层,Dense表示BP层
    46             else:
    47                 self.model.add(Dense(info[i][0], activation=info[i][1], kernel_initializer=info[i][2]))
    48 
    49         # 添加输出层
    50         self.model.add(Dense(info[count-1][0], activation=info[count-1][1], kernel_initializer=info[count-1][2]))
    51 
    52         sgd = SGD(lr=info[0][0], decay=1e-6, momentum=0.9, nesterov=True)
    53         self.model.compile(loss='binary_crossentropy',  optimizer=sgd,  metrics=['accuracy'])  # 编译模型
    54 
    55         self.model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=info[0][1], batch_size=info[0][2])  # 训练模型1000次
    56 
    57         scores_train = self.model.evaluate(x_train, y_train, batch_size=10)
    58         scores_test = self.model.evaluate(x_test, y_test, batch_size=10)
    59         scores = self.model.evaluate(data_x, data_y, batch_size=10)
    60 
    61         self.finish_trigger.emit()         # 循环完毕后发出信号
    62         list = [scores_train[1]*100, scores_test[1]*100, scores[1]*100]
    63         self.evaluate_trigger.emit(list)
    64         result = pd.Series(self.model.predict(data_x).T[0])
    65         result.name = '预测(BP)'
    66         self.result_trigger.emit(result)
    67         K.clear_session()  # 反复调用model 模型
    68 
    69     def save_model(self, save_dir):
    70         self.model.save(save_dir)  # 保存模型

    GUI显示代码(部分):

     1 class MainWindow(QtGui.QMainWindow):
     2     save_dir_signal = QtCore.pyqtSignal(str)  # 传递保存目录信号
     3 
     4 def show_evaluate_result(self, evaluate_result):
     5         help = QtGui.QMessageBox.information(self, '评价结果',
     6                                              "训练集正确率:  %.2f%%
    测试集正确率:  %.2f%%
    数据集正确率:  %.2f%%" %
     7                                              (evaluate_result[0], evaluate_result[1], evaluate_result[2]),
     8                                              QtGui.QMessageBox.Yes)
     9 
    10         self.pop_save_dir()
    11 
    12     def pop_save_dir(self):
    13         msg = QtGui.QMessageBox.information(self, '提示', '是否保存模型?', QtGui.QMessageBox.Yes | QtGui.QMessageBox.No)
    14         if msg == QtGui.QMessageBox.Yes:
    15                 save_dir = QtGui.QFileDialog.getSaveFileName(self, '选择保存目录', 'C:\Users\fuqia\Desktop')
    16 
    17                 if save_dir != '':
    18                     save_dir = save_dir + '.model'
    19                     self.save_dir_signal.emit(save_dir)
    20 
    21     def show_bp_result(self, result):
    22 
    23         self.predict_data = result
    24         TableWidgetDeal.add_predict_data(self.table, result)
    25 
    26     def waiting_label_close(self):
    27         self.label.close()
    28 
    29     def show_waiting(self):
    30         self.label = QtGui.QLabel(self)
    31         self.label.setFixedSize(640, 480)  # 不加的话有问题???
    32         self.label.setWindowFlags(QtCore.Qt.FramelessWindowHint)  # 无边框
    33         self.label.setAttribute(QtCore.Qt.WA_TranslucentBackground)  # 背景透明
    34 
    35         screen = QtGui.QDesktopWidget().screenGeometry()
    36         size = self.label.geometry()
    37         # 如果是self.label.move((screen.width() - size.width()) / 2 , (screen.height() - size.height()) / 2)无法居中
    38         self.label.move((screen.width() - size.width()) / 2 + 240, (screen.height() - size.height()) / 2)
    39 
    40         # 打开gif文件
    41         movie = QtGui.QMovie("./Icon/waiting.gif")
    42         # 设置cacheMode为CacheAll时表示gif无限循环,注意此时loopCount()返回-1
    43         movie.setCacheMode(QtGui.QMovie.CacheAll)
    44         # 播放速度
    45         movie.setSpeed(100)
    46         self.label.setMovie(movie)
    47         # 开始播放,对应的是movie.start()
    48         movie.start()
    49         self.label.show()
    50         q = QtCore.QEventLoop()
    51         q.exec_()
    1 w = WorkThread()
    2 w.init(self.object.data_set, feature, label, self.bp_ui.bp_info)
    3 w.start()
    4 w.finish_trigger.connect(self.waiting_label_close)
    5 w.result_trigger.connect(self.show_bp_result)
    6 w.evaluate_trigger.connect(self.show_evaluate_result)
    7 self.save_dir_signal.connect(w.save_model)
    8 self.show_waiting()
  • 相关阅读:
    Python安装
    solr集群solrCloud的搭建
    redis单机及其集群的搭建
    maven实现tomcat热部署
    maven发布时在不同的环境使用不同的配置文件
    nexus 的使用及maven的配置
    java 自定义注解以及获得注解的值
    Jenkins学习之——(4)Email Extension Plugin插件的配置与使用
    Jenkins学习之——(3)将项目发送到tomcat
    注意Tengine(Nginx) proxy_pass之后的"/"
  • 原文地址:https://www.cnblogs.com/fuqia/p/9191696.html
Copyright © 2011-2022 走看看