zoukankan      html  css  js  c++  java
  • model.fit中的callbacks是做什么的

    model.fit中的callbacks是做什么的

    一、总结

    一句话总结:

    keras的callback参数可以帮助我们实现在训练过程中的适当时机被调用。实现实时保存训练模型以及训练参数。

    二、keras深度训练1:fit和callback

    转自或参考:keras深度训练1:fit和callback
    http://blog.csdn.net/github_36326955/article/details/79794288

    1. model.fit

    model.fit(
        self, 
        x, 
        y, 
        batch_size=32, 
        nb_epoch=10, 
        verbose=1, 
        callbacks=[], 
        validation_split=0.0, 
        validation_data=None, 
        shuffle=True, 
        class_weight=None, 
        sample_weight=None
    )

    其中:

    1. x为输入数据。如果模型只有一个输入,那么x的类型是numpy array,如果模型有多个输入,那么x的类型应当为list,list的元素是对应于各个输入的numpy array。如果模型的每个输入都有名字,则可以传入一个字典,将输入名与其输入数据对应起来。
    2. y:标签,numpy array。如果模型有多个输出,可以传入一个numpy array的list。如果模型的输出拥有名字,则可以传入一个字典,将输出名与其标签对应起来。
    3. batch_size:整数,指定进行梯度下降时每个batch包含的样本数。训练时一个batch的样本会被计算一次梯度下降,使目标函数优化一步。
    4. nb_epoch:整数,训练的轮数,训练数据将会被遍历nb_epoch次。Keras中nb开头的变量均为”number of”的意思
    5. verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录
    6. callbacks:list,其中的元素是keras.callbacks.Callback的对象。这个list中的回调函数将会在训练过程中的适当时机被调用,参考回调函数
    7. validation_split:0~1之间的浮点数,用来指定训练集的一定比例数据作为验证集。验证集将不参与训练,并在每个epoch结束后测试的模型的指标,如损失函数、精确度等。
    8. validation_data:形式为(X,y)或(X,y,sample_weights)的tuple,是指定的验证集。此参数将覆盖validation_spilt。
    9. shuffle:布尔值,表示是否在训练过程中每个epoch前随机打乱输入样本的顺序。请注意:这个shuffle并不是对整个数据集打乱顺序的,而是先split出训练集和验证集,然后对训练集进行shuffle。
    10. class_weight:字典,将不同的类别映射为不同的权值,该参数用来在训练过程中调整损失函数(只能用于训练)。该参数在处理非平衡的训练数据(某些类的训练样本数很少)时,可以使得损失函数对样本数不足的数据更加关注。
    11. sample_weight:权值的numpy array,用于在训练时调整损失函数(仅用于训练)。可以传递一个1D的与样本等长的向量用于对样本进行1对1的加权,或者在面对时序数据时,传递一个的形式为(samples,sequence_length)的矩阵来为每个时间步上的样本赋不同的权。这种情况下请确定在编译模型时添加了sample_weight_mode=’temporal’。12345678910111234567891011

    fit函数返回一个History的对象,其History.history属性记录了损失函数和其他指标的数值随epoch变化的情况,如果有验证集的话,也包含了验证集的这些指标变化情况。

    2. callback

    keras的callback参数可以帮助我们实现在训练过程中的适当时机被调用。实现实时保存训练模型以及训练参数。

    2.1 ModelCheckpoint
    keras.callbacks.ModelCheckpoint(
        filepath, 
        monitor='val_loss', 
        verbose=0, 
        save_best_only=False, 
        save_weights_only=False, 
        mode='auto', 
        period=1
    )

    其中:
    1. filename:字符串,保存模型的路径
    2. monitor:需要监视的值
    3. verbose:信息展示模式,0或1
    4. save_best_only:当设置为True时,将只保存在验证集上性能最好的模型,一般我们都会设置为True.
    5. mode:‘auto’,‘min’,‘max’之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。
    6. save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)
    7. period:CheckPoint之间的间隔的epoch数

    2.2 EarlyStopping
    from keras.callbacksimport EarlyStopping 
    
    keras.callbacks.EarlyStopping(
        monitor='val_loss', 
        patience=0, 
        verbose=0, 
        mode='auto'
    )
    
    model.fit(X, y, validation_split=0.2, callbacks=[early_stopping])

    其中:
    1. monitor:需要监视的量
    2. patience:当early stop被激活(如发现loss相比上一个epoch训练没有下降),则经过patience个epoch后停止训练。
    3. verbose:信息展示模式
    4. mode:‘auto’,‘min’,‘max’之一,在min模式下,如果检测值停止下降则中止训练。在max模式下,当检测值不再上升则停止训练。

    2.3 LearningRateSchedule

    学习率动态调整

    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss', 
        factor=0.1, 
        patience=10, 
        verbose=0, 
        mode='auto', 
        epsilon=0.0001, 
        cooldown=0, 
        min_lr=0
    )

    其中:
    1. monitor:被监测的量
    2. factor:每次减少学习率的因子,学习率将以lr = lr*factor的形式被减少
    3. patience:当patience个epoch过去而模型性能不提升时,学习率减少的动作会被触发
    4. mode:‘auto’,‘min’,‘max’之一,在min模式下,如果检测值触发学习率减少。在max模式下,当检测值不再上升则触发学习率减少。
    5. epsilon:阈值,用来确定是否进入检测值的“平原区”
    6. cooldown:学习率减少后,会经过cooldown个epoch才重新进行正常操作
    7. min_lr:学习率的下限

    当学习停滞时,减少2倍或10倍的学习率常常能获得较好的效果


    自定义动态调整学习率:

    def step_decay(epoch):
        initial_lrate = 0.01
        drop = 0.5
        epochs_drop = 10.0
        lrate = initial_lrate * math.pow(drop,math.floor((1+epoch)/epochs_drop))
        return lrate
    lrate = LearningRateScheduler(step_decay)
    sgd = SGD(lr=0.0, momentum=0.9, decay=0.0, nesterov=False)
    model.fit(train_set_x, train_set_y, validation_split=0.1, nb_epoch=200, batch_size=256, callbacks=[lrate])
    

    具体可以参考这篇文章Using Learning Rate Schedules for Deep Learning Models in Python with Keras

    2.4 记录每一次epoch的训练/验证损失/准确度?

    Model.fit函数会返回一个 History 回调,该回调有一个属性history包含一个封装有连续损失/准确的lists。代码如下:

    hist = model.fit(X, y,validation_split=0.2)  
    print(hist.history)

    Keras输出的loss,val这些值如何保存到文本中去
    Keras中的fit函数会返回一个History对象,它的History.history属性会把之前的那些值全保存在里面,如果有验证集的话,也包含了验证集的这些指标变化情况,具体写法

    hist=model.fit(train_set_x,train_set_y,batch_size=256,shuffle=True,nb_epoch=nb_epoch,validation_split=0.1)
    with open('log_sgd_big_32.txt','w') as f:
        f.write(str(hist.history))
    2.5 TensorBoard
    
    from keras.callbacks import TensorBoard
    
    tensorboard = TensorBoard(log_dir='./logs', histogram_freq=0,
                              write_graph=True, write_images=False)
    # define model
    model.fit(X_train, Y_train,
              batch_size=batch_size,
              epochs=nb_epoch,
              validation_data=(X_test, Y_test),
              shuffle=True,
              callbacks=[tensorboard])

    使用tensorboard时,在终端输入

    tensorboard --logdir path_to_current_dir
    2.5 多个回调函数用逗号隔开
    from keras.callbacks import TensorBoard
    from keras.callbacks import EarlyStopping
    from keras.callbacks import ModelCheckpoint
    from keras.callbacks import ReduceLROnPlateau
    
    
    # callbacks:
    tb = TensorBoard(log_dir='./logs',  # log 目录
                     histogram_freq=1,  # 按照何等频率(epoch)来计算直方图,0为不计算
                     batch_size=32,     # 用多大量的数据计算直方图
                     write_graph=True,  # 是否存储网络结构图
                     write_grads=False, # 是否可视化梯度直方图
                     write_images=False,# 是否可视化参数
                     embeddings_freq=0,
                     embeddings_layer_names=None,
                     embeddings_metadata=None)
    
    es=EarlyStopping(monitor='val_loss', patience=20, verbose=0)
    
    mc=ModelCheckpoint(
        './logs/weight.hdf5',
        monitor='val_loss',
        verbose=0,
        save_best_only=True,
        save_weights_only=False,
        mode='auto',
        period=1
    )
    
    rp=ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.1,
        patience=20,
        verbose=0,
        mode='auto',
        epsilon=0.0001,
        cooldown=0,
        min_lr=0
    )
    
    callbacks = [es,tb,mc,rp]
    
    # start to train out model
    bs = 100
    ne = 1000
    hist = model.fit(data, labels_cat,batch_size=bs,epochs=ne,
                          verbose=2,validation_split=0.25,callbacks=callbacks)
    
    print("train process done!!")
    
     
    我的旨在学过的东西不再忘记(主要使用艾宾浩斯遗忘曲线算法及其它智能学习复习算法)的偏公益性质的完全免费的编程视频学习网站: fanrenyi.com;有各种前端、后端、算法、大数据、人工智能等课程。
    博主25岁,前端后端算法大数据人工智能都有兴趣。
    大家有啥都可以加博主联系方式(qq404006308,微信fan404006308)互相交流。工作、生活、心境,可以互相启迪。
    聊技术,交朋友,修心境,qq404006308,微信fan404006308
    26岁,真心找女朋友,非诚勿扰,微信fan404006308,qq404006308
    人工智能群:939687837

    作者相关推荐

  • 相关阅读:
    vue 父子组件通信props/emit
    mvvm
    Ajax
    闭包
    【CSS3】---only-child选择器+only-of-type选择器
    【CSS3】---last-of-type选择器+nth-last-of-type(n)选择器
    【CSS3】---first-of-type选择器+nth-of-type(n)选择器
    【CSS3】---结构性伪类选择器—nth-child(n)+nth-last-child(n)
    【CSS3】---结构性伪类选择器-first-child+last-child
    vue路由切换和用location切换url的区别
  • 原文地址:https://www.cnblogs.com/Renyi-Fan/p/13703325.html
Copyright © 2011-2022 走看看