zoukankan      html  css  js  c++  java
  • Keras/Tensorflow训练逻辑研究

    Keras是什么,以及相关的基础知识,这里就不做详细介绍,请参考Keras学习站点http://keras-cn.readthedocs.io/en/latest/

    Tensorflow作为backend时的训练逻辑梳理,主要是结合项目,研究了下源代码!

    我们的项目是智能问答机器人,基于双向RNN(准确的说是GRU)网络,这里网络结构,就不做介绍,只研究其中的训练逻辑,我们的训练是基于fit_generator,即基于生成器模型,节省内存,有助效率提升。

    什么是生成器以及生成器的工作原理,这里不表,属于python的基础范畴。

    1. Keras的训练,是基于batch进行的,每一个batch训练过程,进行一次loss和acc的调整

    1.1 .主要核心代码

    A. /home/anaconda2/lib/python2.7/site-packages/keras/legacy/interfaces.py

    1)里面的装饰器函数generate_legacy_interface里面。这里涉及到fit_generator这个最为核心的入口函数的执行过程。

    2)python里面装饰器工作原理,非常类似java代码里面的AOP切面编程逻辑,即在正常的业务逻辑执行前,将before或者after或者两者都执行一下。

    3)训练函数原型及重要参数解释

    def fit_generator(self, generator,        #生成器,一个yield的函数,迭代返回数据
                 steps_per_epoch,             #一次训练周期(具体epoch是什么含义,要理解清楚)里面进行多少次batch
                 epochs=1,                    #设置进行几次全数据集的训练,每一次全数据集训练过程被定义成一个epoch,其实这个是可以灵活应用的
                 verbose=1,                   #一个开关,打开时,打印清晰的训练数据,即加载ProgbarLogger这个回调函数
                 callbacks=None,              #设置业务需要的回调函数,我们的模型中添加了ModelCheckpoint这个回调函数
                 validation_data=None,        #验证用的数据源设置,evaluate_generator函数要用到这个数据源,我们的项目里面,这里也是一个生成器
                 validation_steps=None,       #设置验证多少次数据后取平均值作为此epoch训练后的效果,val_loss,val_acc的值受这个参数直接影响
                 class_weight=None,           #此参数以及后续参数,我们的项目采用的都是默认值,可以参考官方文档了解细节
                 max_queue_size=10,
                 workers=1,
                 use_multiprocessing=False,
                 initial_epoch=0)

    B. /home/anaconda2/lib/python2.7/site-packages/keras/callbacks.py

    1)这里重点有ModelCheckpoint这个回调函数,涉及到业务参数,其他回调都是keras框架默认行为。

    2)callback这个类,其实是一个容器,具体表现为一个List,可以在git_generator运行时,基于该函数的入参,构建一个Callback的实例,即一个list里面装入业务需要的callback实例,这里默认会有BaseLogger以及History这个callback,然后会判断verbose为true时,会添加ProgbarLogger这个callback,除此之外,就是fit_generator函数入参callbacks传入的参数。一般都会传递ModelCheckpoint这个。

    3)在git_generator这个基于生成器模式训练的过程中,每一个epoch结束(on_epoch_end)时,都要调用这个callback函数(ModelCheckpoint)进行模型数据写文件的操作

    2. Keras训练时用到的几个重要回调函数(主要工作在on_batch_end里面)

    回调函数是基于抽象类Callback实现的。下面是Callback的成员函数,便于理解。

       def __init__(self):
            self.validation_data = None
    
        def set_params(self, params):
            self.params = params
    
        def set_model(self, model):
            self.model = model
    
        def on_epoch_begin(self, epoch, logs=None):
            pass
    
        def on_epoch_end(self, epoch, logs=None):
            pass
    
        def on_batch_begin(self, batch, logs=None):
            pass
    
        def on_batch_end(self, batch, logs=None):
            pass
    
        def on_train_begin(self, logs=None):
            pass
    
        def on_train_end(self, logs=None):
            pass

    A. keras.callbacks.BaseLogger

    统计该batch里面训练的loss以及acc的值,计入totals,乘以batch_size后。

    def on_batch_end(self, batch, logs=None):
            logs = logs or {}
            batch_size = logs.get('size', 0)
            self.seen += batch_size
    
            for k, v in logs.items():
                if k in self.totals:
                    self.totals[k] += v * batch_size
                else:
                    self.totals[k] = v * batch_size

    在BaseLogger这个类的on_epoch_end函数里,执行对这个epoch训练数据的loss以及acc求平均值。

    def on_epoch_end(self, epoch, logs=None):
            if logs is not None:
                for k in self.params['metrics']:
                    if k in self.totals:
                        # Make value available to next callbacks.
                        logs[k] = self.totals[k] / self.seen

    B. keras.callbacks.ModelCheckpoint

    在on_epoch_end时会保存模型数据进入文件

    def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            self.epochs_since_last_save += 1
            if self.epochs_since_last_save >= self.period:
                self.epochs_since_last_save = 0
                filepath = self.filepath.format(epoch=epoch, **logs)
                if self.save_best_only:
                    current = logs.get(self.monitor)
                    if current is None:
                        warnings.warn('Can save best model only with %s available, '
                                      'skipping.' % (self.monitor), RuntimeWarning)
                    else:
                        if self.monitor_op(current, self.best):
                            if self.verbose > 0:
                                print('Epoch %05d: %s improved from %0.5f to %0.5f,'
                                      ' saving model to %s'
                                      % (epoch, self.monitor, self.best,
                                         current, filepath))
                            self.best = current
                            if self.save_weights_only:
                                self.model.save_weights(filepath, overwrite=True)
                            else:
                                self.model.save(filepath, overwrite=True)
                        else:
                            if self.verbose > 0:
                                print('Epoch %05d: %s did not improve' %
                                      (epoch, self.monitor))
                else:
                    if self.verbose > 0:
                        print('Epoch %05d: saving model to %s' % (epoch, filepath))
                    if self.save_weights_only:
                        self.model.save_weights(filepath, overwrite=True)
                    else:
                        self.model.save(filepath, overwrite=True)

    C.keras.callbacks.History

    主要记录每一次epoch训练的结果,结果包含loss以及acc的值

    D. keras.callbacks.ProgbarLogger

    这个函数里面实现训练中间状态数据信息的输出,主要涉及进度相关信息。

    3. 具体训练逻辑过程

    A. 训练函数分析

    a. model.fit_generator 训练入口函数(参考上面的函数原型定义), 我们项目中用tk_data_generator函数作为训练数据提供者(生成器)
    1) callbacks.on_train_begin()
    2) while epoch < epochs:
    3)         callbacks.on_epoch_begin(epoch)
    4)         while steps_done < steps_per_epoch:
    5)             generator_output = next(output_generator)       #生成器next函数取输入数据进行训练,每次取一个batch大小的量
    6)             callbacks.on_batch_begin(batch_index, batch_logs)
    7)             outs = self.train_on_batch(x, y,sample_weight=sample_weight,class_weight=class_weight)
    8)             callbacks.on_batch_end(batch_index, batch_logs)
                end of while steps_done < steps_per_epoch
                self.evaluate_generator(...)          #当一个epoch的最后一次batch执行完毕,执行一次训练效果的评估
    9)      callbacks.on_epoch_end(epoch, epoch_logs)          #在这个执行过程中实现模型数据的保存操作
          end of while epoch < epochs
    10) callbacks.on_train_end()


    b. 特别介绍下train_on_batch
       train_on_batch (keras中的trainning.py)
            |_self._standardize_user_data
            |_self._make_train_function
            |_self.train_function (tensorflow的函数)
                            |_updated = session.run(self.outputs + [self.updates_op], feed_dict=feed_dict,**self.session_kwargs)

    B训练和验证的对比

    a. 在每一个epoch的最后一个迭代(最后一次batch)时,要进行此轮epoch的校验(evaluate)

    日志如下:

    141/141 [==============================] - 12228s - loss: 0.5715 - acc: 0.6960 - val_loss: 0.5082 - val_acc: 0.7450
    
    
    第一个141表示batch_index已经达到141,即steps_per_epoch参数规定的最后一步
    第二个141表示steps_per_epoch,即一个epoch里面进行多少次batch处理
    12228s 表示此batch处理结束所花费的时间
    loss:此epoch里面的平均损失值
    acc:此epoch里面的平均准确率   
    val_loss:此epoch训练完后进行的evaluate得到的损失值
    val_acc:此epoch训练完后进行的evaluate得到的正确率

    b. 验证逻辑,和训练逻辑差不多,只是将validation_steps指定次数的test的值进行取平均值,得到validation_steps次test的均值作为本epoch训练的最终效果

    self.evaluate_generator(validation_data,validation_steps,max_queue_size=max_queue_size,workers=workers,use_multiprocessing=use_multiprocessing)

    1) while steps_done < steps:
    2)           generator_output = next(output_generator)
    3)         outs = self.test_on_batch(x, y, sample_weight=sample_weight)
    4)对上述while得到的每次outs进行 averages.append(np.average([out[i] for out in all_outs],weights=batch_sizes))

    其中重点test_on_batch

    test_on_batch(self, x, y, sample_weight=None)
             |_self._standardize_user_data(x, y,sample_weight=sample_weight,check_batch_axis=True)
             |_self._make_test_function()
             |_self.test_function(ins)                    
                        |_updated = session.run(self.outputs + [self.updates_op],feed_dict=feed_dict,**self.session_kwargs)

    c. train和test的重要区别,应该体现在下面的两个函数上

    def _make_train_function(self):
            if not hasattr(self, 'train_function'):
                raise RuntimeError('You must compile your model before using it.')
            if self.train_function is None:
                inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
                if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
                    inputs += [K.learning_phase()]
    
                with K.name_scope('training'):
                    with K.name_scope(self.optimizer.__class__.__name__):
                        training_updates = self.optimizer.get_updates(
                            params=self._collected_trainable_weights,
                            loss=self.total_loss)
                    updates = self.updates + training_updates
                    # Gets loss and metrics. Updates weights at each call.
                    self.train_function = K.function(inputs,
                                                     [self.total_loss] + self.metrics_tensors,
                                                     updates=updates,
                                                     name='train_function',
                                                     **self._function_kwargs)
    def _make_test_function(self):
            if not hasattr(self, 'test_function'):
                raise RuntimeError('You must compile your model before using it.')
            if self.test_function is None:
                inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
                if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
                    inputs += [K.learning_phase()]
                # Return loss and metrics, no gradient updates.
                # Does update the network states.
                self.test_function = K.function(inputs,
                                                [self.total_loss] + self.metrics_tensors,
                                                updates=self.state_updates,
                                                name='test_function',
                                                **self._function_kwargs)

    经过前面的代码逻辑梳理,可以看到不管是train的过程还是test的过程,最终底层都是调用Tensorflow的session.run方法进行loss和acc的获取,细心的观察,会发现两个session.run函数的入参其实有点不同。

    结合上面train和test的私有函数中标注红色的注释,以及用K.function生成函数的入参中,可以看出train和test的差异。

    总结:

    0. 训练过程中,每次权重的更新都是在一个batch上进行一次,是基于batch量的数据为单位进行一次权重的更新

    1. 基于生成器模型训练数据,可以提升效率,降低对物理服务器性能,尤其是内存的要求

    2. 训练过程中,Callback函数执行了大量的工作,包括loss、acc值的记录,以及训练中间结果的日志反馈,最重要的是模型数据的输出,也是通过callback的方式实现(ModelCheckpoint)

    3. 训练(train)和验证(evaluate/validate)的逻辑近乎一样,训练要更新权重,但是验证过程,仅仅更新网络状态,不涉及权重(loss以及acc参数)信息的更新

    4. 代码梳理过程中,得出结论,Keras对python编程基本功底要求还是有点高的,采用了推导式编程习惯,生成器,装饰器,回调等编程思想,另外,对矩阵运算,例如numpy.dot以及numpy.multiply的数学逻辑都有一定要求,否则比较难看懂。

  • 相关阅读:
    PCA,到底在做什么
    论文笔记:Deep feature learning with relative distance comparison for person re-identification
    论文笔记:Cross-Domain Visual Matching via Generalized Similarity Measure and Feature Learning
    word2vec概述
    登录获取token,token参数关联至所有请求的请求体内
    pip安装库时报错,使用国内镜像加速
    python+unittest+requests+HTMLRunner编写接口自动化测试集
    python实现http get请求
    python实现以application/json格式为请求体的http post请求
    反编译apk
  • 原文地址:https://www.cnblogs.com/shihuc/p/8485651.html
Copyright © 2011-2022 走看看