zoukankan      html  css  js  c++  java
  • keras训练实例-python实现

    用keras训练模型并实时显示loss/acc曲线,(重要的事情说三遍:实时!实时!实时!)实时导出loss/acc数值(导出的方法就是实时把loss/acc等写到一个文本文件中,其他模块如前端调用时可直接读取文本文件),同时也涉及了plt画图方法

    ps:以下代码基于网上的一段程序修改完成,如有侵权,请联系我哈!

    上代码:

    from keras import Sequential, initializers, optimizers
    from keras.layers import Activation, Dense
    import numpy as np
    import pylab as pl
    from IPython import display
    from keras.callbacks import Callback
    from keras.datasets import mnist
    import keras
    from keras.layers import Conv2D, MaxPooling2D
    from keras.layers import Dense, Dropout, Flatten
    
    #定义回调函数的类,用于实时显示loss/acc曲线和导出loss/acc数值
    class DrawCallback(Callback):
        def __init__(self, runtime_plot=True): # 初始化
    
            self.init_loss = None
            self.init_val_loss = None
            self.init_acc = None
            self.init_val_acc = None
            self.runtime_plot = runtime_plot
            
            self.xdata = []
            self.ydata = []
            self.ydata2 = []
            self.ydata3 = []
            self.ydata4 = []
        def _plot(self, epoch=None):
            epochs = self.params.get("epochs")
            pl.subplot(121) #画第一个图,121表示纵向1个图,横向2个图,当前第1个图
            pl.ylim(0, int(self.init_loss*1.2)) #限制坐标轴范围
            pl.xlim(0, epochs)
            pl.plot(self.xdata, self.ydata,'r', label='loss') #xdata/ydata均为不断增长的一维数组,同时定义了线段颜色/类型/图例
            pl.plot(self.xdata, self.ydata2, 'b--', label='val_loss')
            pl.xlabel('Epoch {}/{}'.format(epoch or epochs, epochs)) #坐标轴显示变化的标签
            pl.ylabel('Loss {:.4f}'.format(self.ydata[-1]))
            pl.legend() #显示图例,不加这个即便是定义图例了也没用
            pl.title('loss') #显示标题
            
            pl.subplot(122)
            pl.ylim(0, 1.2)
            pl.xlim(0, epochs)
            pl.plot(self.xdata, self.ydata3,'r', label='acc')
            pl.plot(self.xdata, self.ydata4, 'b--', label='val_acc')
            pl.xlabel('Epoch {}/{}'.format(epoch or epochs, epochs))
            pl.ylabel('Loss {:.4f}'.format(self.ydata[-1]))
            pl.legend()
            pl.title('acc')
            
        def _runtime_plot(self, epoch):
            self._plot(epoch)
            #不断的清图
            display.clear_output(wait=True)
            display.display(pl.gcf())
            pl.gcf().clear()
            
        def plot(self):
            self._plot()
            pl.show() #显示窗口
        
        def on_epoch_end(self, epoch, logs = None): #更新xdata/ydata
            logs = logs or {}
    #         batch_size = self.params.get("batch_size")
            epochs = self.params.get("epochs") #获取训练相关数据
            loss = logs.get("loss")
            val_loss = logs.get("val_loss")
            acc = logs.get("acc")
            val_acc = logs.get("val_acc")
            
            epochs_str = str(epochs)[0:6] #为了写入txt,必须转为字符型,为了美观只保留小数点后4位
            loss_str = str(loss)[0:6]
            val_loss_str = str(val_loss)[0:6]
            acc_str = str(acc)[0:6]
            val_acc_str = str(val_acc)[0:6]
            
            f = open('logs_r/record.txt','a') #要用追加方式‘a’写入txt,所在行数就是当前迭代次数
            f.write('epochs:{}_loss:{}_val_loss:{}_acc:{}_val_acc{}'.format(epochs_str,loss_str,val_loss_str,acc_str,val_acc_str))
            f.write('
    ')
            f.close()
    
            if self.init_loss is None: #增加xdata/ydata内容
                self.init_loss = loss
                self.init_val_loss = val_loss
            self.xdata.append(epoch)
            self.ydata.append(loss)
            self.ydata2.append(val_loss)
            self.ydata3.append(acc)
            self.ydata4.append(val_acc)
            if self.runtime_plot:
                self._runtime_plot(epoch)
    
    # 下面开始构建keras需要的东西
    def viz_keras_fit(runtime_plot=False):
        d = DrawCallback(runtime_plot = runtime_plot) #实例化回调函数
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
        x_train = x_train.reshape(-1,28,28,1)
        x_test = x_test.reshape(-1,28,28,1)
        input_shape = (28,28,1)
        x_train = x_train/255
        x_test = x_test/255
        y_train = keras.utils.to_categorical(y_train,10)
        y_test = keras.utils.to_categorical(y_test,10)
        #为了减小计算量,减少了训练/测试数据
        x_train = x_train[0:600,:,:,:]
        x_test = x_test[0:100,:,:,:]
        y_train = y_train[0:600,:]
        y_test = y_test[0:100,:]
        
        model = Sequential() #实例化一个模型
        #接下来一顿操作,就是搭建网络
        model.add(Conv2D(filters=32, kernel_size=(3,3),
                    activation='relu', input_shape=input_shape,
                    name='conv1'))
        model.add(Conv2D(64,(3,3),activation='relu',name='conv2'))
        model.add(MaxPooling2D(pool_size=(2,2),name='pool2'))
        model.add(Dropout(0.25,name='dropout1'))
        model.add(Flatten(name='flat1'))
        model.add(Dense(128,activation='relu'))
        model.add(Dropout(0.5,name='dropout2'))
        model.add(Dense(10,activation='softmax',name='output'))
        #编译网络,同时定义了loss方法/优化方法/监测内容
        model.compile(loss=keras.losses.categorical_crossentropy,
                 optimizer=keras.optimizers.Adadelta(),
                 metrics=['accuracy'])
        #开始训练
        model.fit(x = x_train,
                 y = y_train,
                 epochs=30,
                 verbose=0, #当值为1时,会打印训练过程
                  validation_data=(x_test, y_test), #加入测试数据,不然有些数据时看不到的
                 callbacks=[d]) #指定回调函数
        return d
    

      

    最后运行:

    viz_keras_fit(runtime_plot=True) #调用函数

    显示结果:

  • 相关阅读:
    Permission denied (publickey). fatal: Could not read from remote repository.
    jQuery OCUpload一键上传文件
    org.apache.subversion.javahl.ClientException: Working copy is not up-to-date
    测开之路一百三十五:实现登录身份验证功能
    测开之路一百三十四:实现指定查找功能
    测开之路一百三十三:实现sql函数封装
    测开之路一百三十二:实现修改功能
    测开之路一百三十一:实现删除功能
    测开之路一百三十:实现前端到数据库交互(增和查)
    测开之路一百二十九:jinja2模板语法
  • 原文地址:https://www.cnblogs.com/niulang/p/11752914.html
Copyright © 2011-2022 走看看