zoukankan      html  css  js  c++  java
  • tensorflow variable的保存和修改(加载一部分variable到新的model中)

    link: https://www.tensorflow.org/guide/saved_model

    中文博客:https://blog.csdn.net/Searching_Bird/article/details/78274207

      https://blog.csdn.net/mieleizhi0522/article/details/80535189 

    self.saver = tf.train.Saver({'words/_word_embeddings':self._word_embeddings})

    for x in tf.all_variables():
    print(x.name)

    mysaver = tf.train.import_meta_graph(self.config.dir_model_storepath_pre_graph)
    mysaver.restore(self.sess, tf.train.latest_checkpoint(self.config.dir_model_storepath_pre))

    一,恢复部分预训练模型的参数。

    weight=[weights['wc1'],weights['wc2'],weights['wc3a']]
    saver = tf.train.Saver(weight)#创建一个saver对象,.values是以列表的形式获取字典值
    saver.restore(sess, model_filename)
    二,手动初始化剩下的(预训练模型中没有的)参数。

    var = tf.get_variable(name, shape, initializer=tf.contrib.layers.xavier_initializer())
     

    保存的时候怎么保存呢?我想保存全部变量,所以要重新写一个对象,名字和恢复的那个saver对象不同:

    saver_out=tf.train.Saver()
    saver_out.save(sess,'file_name')
    这个时候就保存了全部变量,如果你想保存部分变量,只需要在构造器里传入想要保存的变量的名字就行了。

  • 相关阅读:
    etcd的原理分析
    (转)Linux sort命令
    随机森林
    python 类的定义和继承
    python random
    Spark源码阅读(1): Stage划分
    Mac 上安装MySQL
    Python 删除 数组
    在循环中将多列数组组合成大数组
    准确率 召回率
  • 原文地址:https://www.cnblogs.com/wuxiangli/p/10317128.html
Copyright © 2011-2022 走看看