zoukankan      html  css  js  c++  java
  • 模型文件(checkpoint)对模型参数的储存与恢复

    1.  模型参数的保存:

    import tensorflow as tf
    w=tf.Variable(0.0,name='graph_w')
    ww=tf.Variable(tf.random_normal(shape=(2,3),stddev=0.5),name='graph_ww')
    # double=tf.multiply(2.0,w)
    saver=tf.train.Saver({'weights_w':w,'weights_ww':ww}) # 此处模型文件关键字可以自己命名,如weights_w与weights_ww
    # 关键字所对应的值名字为变量w与ww,而不是graph_w与graph_ww,否则会报错。{'weights_w':w,'weights_ww':ww}为模型文件
    # 需要保存的变量,用字典形式书写出来,若无此字典,默认保存全部。
    sess=tf.Session()
    sess.run(tf.global_variables_initializer())
    for i in range(4):
    d=sess.run(tf.assign_add(w,2)) # 这一步对w进行计算,得到最后值为8,最终将其保存saver种
    # 其中 w 必须为变量名为w,不能是graph中的graph_w,否则会报错
    print(d)
    print('w=',sess.run(w))
    print('ww=',sess.run(ww))
    saver.save(sess,'test.ckpt')




    2. 模型参数的恢复:

    import tensorflow as tf
    restore_w=tf.Variable(0.0,name='weights_w')
    restore_ww=tf.Variable(tf.random_normal(shape=(2,3),stddev=0.5),name='weights_ww') # 尽管有初始值,但未调用tf.global_variables_initializer()此函数,则不会将其初始值赋值给该变量
    # restore_w与restore_ww分别对应保存变量的w与ww的恢复,若想恢复则graph必须是对应保存变量名对应字典的关键字,
    # 否则将会报错。即恢复对应变量参数的变量名字可以自己重新命名,但graph中的名字必须是字典关键字。
    double=tf.multiply(2.0,restore_w)
    saver=tf.train.Saver()
    sess=tf.Session()
    saver.restore(sess,'test.ckpt')
    f=sess.run(double)
    print(f)
    print('restore_ww=',sess.run(restore_ww))

    总结:
      ① 变量初始化有2种方法,若不调用tf.global_variables_initializer()或tf.variables_initializer()函数就不会将变量restore_ww=tf.Variable(tf.random_normal(shape=(2,3),stddev=0.5),name='weights_ww')
        初始化。相反,使用saver.restore()将其变量初始化了。同时,也说明变量初始化才不会报错。
      ② 模型参数以字典形式保存,其key可自己命名,其value必须为变量(非graph的name)。
      ③ 模型参数恢复对应变量可以自己命名,但对应变量中graph的name必须是保存模型参数对应变量的关键字(key)。



  • 相关阅读:
    Win8常用快捷键
    清除远程桌面连接记录
    通过注册表改变“我的文档”等的默认位置,防止系统重装造成数据丢失
    HTML 转义字符对照表
    r语言 函数
    sparkr——报错
    R语言--saprkR基本使用
    r绘图基本
    R绘制中国地图,并展示流行病学数据
    r画饼图
  • 原文地址:https://www.cnblogs.com/tangjunjun/p/11682065.html
Copyright © 2011-2022 走看看