zoukankan      html  css  js  c++  java
  • day-20 tensorflow持久化之入门学习

           如果不对模型参数进行保存,当训练结束以后,模型也在内存中被释放,下一轮又需要对模型进行重新训练,有没有一种方法,可以利用之前已经训练好的模型参数值,直接进行模型推理或者继续训练?这里需要引入一个数据之久化的概念,其通用定义就是将内存中的数据模型转换为存储模型,以及将存储模型转换为内存中的数据模型的统称。

           OK,在tensorflow中,持久化可以是我们训练好的神经网络权重值和biase值写入到文件中,下一次直接从文件中进行读取,而不需要重新对模型进行训练。

            用tensorflow写一个简单的示例:求两个变量v1和v2的和,然后将其保存result变量中,然后将其保存到文件中,下一次训练时直接读取文件。

            先看保存程序:

    import tensorflow as tf
    
    # 定义两个变量,并对其进行求和
    v1 = tf.Variable(tf.constant(value=1.0,dtype=tf.float32,shape=[1],name="v1"))
    v2 = tf.Variable(tf.constant(value=2.0,dtype=tf.float32,shape=[1],name="v2"))
    result = v1 + v2
    
    # 将求和操作加到result集合中
    tf.add_to_collection('result',result)
    
    # 新建一个持久化对象
    saver = tf.train.Saver()
    
    # 运行会话,并持久化模型
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        # 如下操作执行完以后,会在sample_test目录下生成四个文件:
        # checkpoint:所有模型文件列表
        # model.data-00000-of-00001:
        # model-index:
        # model.meta:计算图的结构
        saver.save(sess=sess,save_path="sample_test/model")

    如果要重新加载模型,新的代码可以这么写:

    import tensorflow as tf
    
    # 从之前保存的点新建一个持久化对象
    saver = tf.train.import_meta_graph(meta_graph_or_file="sample_test/model.meta")
    
    with tf.Session() as sess:
        # 重新加载保存的参数值
        saver.restore(sess=sess,save_path="sample_test/model")
        # 注意get_collection返回一个列表,如果直接运行,结果也是一个List,注意比较下面的区别:
        print(tf.get_collection(key="result"))
        print(sess.run(tf.get_collection(key='result')))
        '''
        [<tf.Tensor 'add:0' shape=(1,) dtype=float32>]
        [array([3.], dtype=float32)]
        '''
    
        print(tf.get_collection(key="result")[0])
        print(sess.run(tf.get_collection(key='result')[0]))
        '''
        tf.Tensor 'add:0' shape=(1,)
        [3.]
        '''

    进一步,如果我们的网络结构加入了滑动平均模型,重新加载模型时,我们往往是希望用其进行验证,需要使用滑动平均模型参数的值,一个完整的示例如下:

    训练时:

    # 导入库
    import tensorflow as tf
    
    # 定义一个变量
    v = tf.Variable(initial_value=0,dtype=tf.float32,name='v')
    
    # 显示当前有哪些变量
    # <tf.Variable 'v:0' shape=() dtype=float32_ref>
    for variable in tf.global_variables():
        print(variable)
    
    # 定义一个滑动平均模型,和变量应用模型的操作
    ema = tf.train.ExponentialMovingAverage(0.999)
    maintain_average_op = ema.apply(tf.global_variables())
    
    # 显示当前有哪些变量
    # <tf.Variable 'v:0' shape=() dtype=float32_ref>
    # <tf.Variable 'v/ExponentialMovingAverage:0' shape=() dtype=float32_ref>
    for variable in tf.global_variables():
        print(variable)
    
    saver = tf.train.Saver()
    
    # 执行会话,并持久化
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        sess.run(tf.assign(ref=v,value=10))
        sess.run(maintain_average_op)
        saver.save(sess=sess,save_path="sample_test/model")
        print(sess.run([v,ema.average(v)]))

    重新加载时:

    import tensorflow as tf
    
    v = tf.Variable(0,dtype=tf.float32,name='v')
    
    ema = tf.train.ExponentialMovingAverage(0.999)
    saver = tf.train.Saver(ema.variables_to_restore())
    
    # {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
    print(ema.variables_to_restore())
    
    with tf.Session() as sess:
        saver.restore(sess,save_path="sample_test/model")
        # 自动加载滑动平均值来代替变量的值
        # 0.009999871
        print(sess.run(v))
  • 相关阅读:
    call和apply的区别
    淘宝镜像(cnpm)的安装和使用
    文件包含漏洞
    vue简单的日历
    微信小程序(mpvue)—解决视频播放bug的一种方式
    vue 异步组件
    vuex的学习笔记
    vue2.0 添加监听滚动事件
    jquery tmpl生成导航
    vue 控制视图
  • 原文地址:https://www.cnblogs.com/python-frog/p/9478808.html
Copyright © 2011-2022 走看看