zoukankan      html  css  js  c++  java
  • Tensorflow模型保存与加载

       在使用Tensorflow时,我们经常要将以训练好的模型保存到本地或者使用别人已训练好的模型,因此,作此笔记记录下来。

       TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取。tf.train.Saver对象saver的save方法将TensorFlow模型保存到指定路径中,如:saver.save(sess, "/Model/model"), 执行完,在相应的目录下将会有4个文件:

         meta:文件保存的是图结构信息,meta文件是pb(protocol buffer)格式文件,包含变量、op、集合等。

        ckpt保存每个变量的取值,此处文件名的写入方式会因不同参数的设置而不同。是二进制文件,保存了所有的weights、biases、gradients等变量。在tensorflow 0.11之 前,保存在.ckpt文件中。0.11后,通过两个文件保存,如:.data-00000-of-00001和.index文件

         checkpoint文件:checkpoint_dir目录下还有checkpoint文件,该文件是个文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。在inference时,可以通过修改这个文件,指定使用哪个model。加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的。

        保存模型时,只会保存变量的值,placeholder里面的值不会被保存。

      关于save()方法的参数记录:

        •  sess:在tensorflow中,变量是存在于Session环境中,即只有在Session环境下才会存有变量值,因此,保存模型时需要传入session
        • global_step:在n次迭代后,再保存模型,只需设置global_step参数即可
        • 由于图是不变的,没必要每次都去保存,可以在多次迭代过程中只用保存一次模型即可,可以通过设置write_meta_graph=False即可
        • keep_checkpoint_every_n_hours:用来设置间隔时间来保存
        • max_to_keep: 用来设置保存最近模型文件的个数
        • 如果不想保存所有变量,而只保存一部分变量,可以通过指定variables/collections,默认是保存所有的变量。

        tf.train.Saver类也支持在保存和加载时给变量重命名,声明Saver类对象的时候使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名}。

     

      导入模型

        加载图:saver=tf.train.import_meta_graph(.meta文件)即可。

        加载模型参数:aver.restore(sess, tf.train.latest_checkpoint('./checkpoint_dir'))

    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name("w1:0")
    w2 = graph.get_tensor_by_name("w2:0")
    feed_dict = {w1: 13.0, w2: 17.0}
    注意w1:0是tensor的name,既可以指定变量名称,也可以指定操作名称。

      其实,我们也可以只恢复图的一部分,并且再加入其它的op用于fine-tuning。只需通过graph.get_tensor_by_name()方法获取需要的op,并且在此基础上建立图即可。例如:假设我们想使用已经训练好的VGG模型,并且要更改部分层,如下:

    saver = tf.train.import_meta_graph('vgg.meta')
    # 访问图
    graph = tf.get_default_graph() 
    
    #访问用于fine-tuning的output
    fc7= graph.get_tensor_by_name('fc7:0')
    
    #如果你想修改最后一层梯度,需要如下
    fc7 = tf.stop_gradient(fc7) # It's an identity function
    fc7_shape= fc7.get_shape().as_list()
    
    new_outputs=2
    weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
    biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
    output = tf.matmul(fc7, weights) + biases
    pred = tf.nn.softmax(output)

     

  • 相关阅读:
    ubuntu>雷鸟只能收邮件不能发邮件
    ubuntu>安装jdk(转)
    ubuntu>修改root密码
    ios>Could not instantiate class named NSLayoutConstraint(转)
    ios>xcode4.5 如何找到以前的iphone模拟器(转)
    Windows7系统开始菜单改成经典样式
    ASP.NET多语言版的开发
    Dynamic repositories in LightSpeed
    Enhancing queries in dynamic repositories
    C# 4.0 Dynamic关键字全解析
  • 原文地址:https://www.cnblogs.com/czx1/p/9557613.html
Copyright © 2011-2022 走看看