zoukankan      html  css  js  c++  java
  • tensorflow中moving average的用法

    一般在保存模型参数的时候,都会保存一份moving average,是取了不同迭代次数模型的移动平均,移动平均后的模型往往在性能上会比最后一次迭代保存的模型要好一些。

    tensorflow-models项目中tutorials下cifar中相关的代码写的有点问题,在这写下我自己的做法:

    1.构建训练模型时,添加如下代码

    1 variable_averages = tf.train.ExponentialMovingAverage(0.999, global_step)
    2 variables_averages_op = variable_averages.apply(tf.trainable_variables())
    3 ave_vars = [variable_averages.average(var) for var in tf.trainable_variables()]
    4 train_op = tf.group(train_op, variables_averages_op)

    第1行创建了一个指数移动平均类 variable_averages

    第2行将variable_averages作用于当前模型中所有可训练的变量上,得到 variables_averages_op操作符

    第3行获得所有可训练变量对应的移动平均变量列表集合,后续用于保存模型

    第4行在原有的训练操作符基础上,再添加variables_averages_op操作符,后续session执行run的时候,除了训练时前向后向,梯度更新,还会对相应的变量做移动平均

    2.开始训练前,创建saver时,使用如下代码

    1 save_vars = tf.trainable_variables() + ave_vars
    2 saver = tf.train.Saver(var_list=save_vars, max_to_keep=5)

    第1行获取所有需要保存的变量列表,这个时候 ave_vars就派上用场了。

    第2行创建saver,指定var_list为所有可训练变量及其对应的移动平均变量。

    另外需要注意的是,如果你的模型中有bn或者类似层,包含有统计参数(均值、方差等),这些不属于可训练参数,还需要额外添加进save_vars中,可以参考我的这篇博客

    3.在做inference的时候,利用如下代码从checkpoint中恢复出移动平均模型

    1 variable_averages = tf.train.ExponentialMovingAverage(0.999)
    2 variables_to_restore = variable_averages.variables_to_restore()
    3 saver = tf.train.Saver(variables_to_restore)
    4 saver.restore(sess, model_path)

    这几行很简单,就不做解释了。

    实际上,在inference的时候,刚刚的做法除了可以从checkpoint文件中恢复出移动平均参数,还可以恢复出对应迭代的模型参数,可以用来对比两种方式,哪种效果更好,这时只需要将上面代码的第3行改为saver = tf.train.Saver(tf.trainable_variables())即可(和保存时相同,如果有bn,也需要额外考虑)。在我的测试中,使用移动平均参数效果更佳。

  • 相关阅读:
    Linux 如何查看当前目录
    Docker快速入手实战笔记
    【ssh】ssh登录出现‘The authenticity of host ‘IP’ can't be established.’的问题
    【AFL(七)】afl-fuzz.c小改——输出文件夹暂存
    【steam】Steam背景美化——长展柜终极指南
    【AFL(六)】AFL源码中的那些头文件
    【AFL(五)】文件变异策略
    【Latex】详细的简易教程——写在论文开始之前
    【Latex】论文写作工具:VScode 2019 + latex workshop
    【AFL(四)】afl-cmin修改:文件夹相关操作鲁棒性
  • 原文地址:https://www.cnblogs.com/hrlnw/p/8067214.html
Copyright © 2011-2022 走看看