zoukankan      html  css  js  c++  java
  • TF的模型文件

    TF的模型文件

    标签(空格分隔): TensorFlow


    Saver

    tensorflow模型保存函数为:

    tf.train.Saver()
    

    当然,除了上面最简单的保存方式,也可以指定保存的步数,多长时间保存一次,磁盘上最多保有几个模型(将前面的删除以保持固定个数),如下:

    创建saver时指定参数:

    saver = tf.train.Saver(savable_variables, max_to_keep=n, keep_checkpoint_every_n_hours=m)
    

    其中:

    • savable_variables指定待保存的变量,比如指定为tf.global_variables()保存所有global变量;指定为[v1, v2]保存v1和v2两个变量;如果省略,则保存所有;
    • max_to_keep指定磁盘上最多保有几个模型;
    • keep_checkpoint_every_n_hours指定多少小时保存一次。

    保存模型时指定参数:

    saver.save(sess, 'model_name', global_step=step,write_meta_graph=False)
    

    如上,其中可以指定模型文件名,步数,write_meta_graph则用来指定是否保存meta文件记录graph等等。

    示例:

    import tensorflow as tf
    ​
    v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
    v2= tf.Variable(tf.zeros([200]), name="v2")
    v3= tf.Variable(tf.zeros([100]), name="v3")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        saver.save(sess,"checkpoint/model.ckpt",global_step=1)
    

    运行后,保存模型保存,得到四个文件:

    • checkpoint
    • model.ckpt-1.data-00000-of-00001
    • model.ckpt-1.index
    • model.ckpt-1.meta

    checkpoint中记录了已存储(部分)和最近存储的模型:

    model_checkpoint_path: "model.ckpt-1"
    all_model_checkpoint_paths: "model.ckpt-1"
    ...
    

    meta file保存了graph结构,包括 GraphDef,SaverDef等,当存在meta file,我们可以不在文件中定义模型,也可以运行,而如果没有meta file,我们需要定义好模型,再加载data file,得到变量值。

    index file为一个string-string table,table的key值为tensor名,value为serialized BundleEntryProto。每个BundleEntryProto表述了tensor的metadata,比如那个data文件包含tensor、文件中的偏移量、一些辅助数据等。

    data file保存了模型的所有变量的值,TensorBundle集合。

    Restore

    Restore模型的过程可以分为两个部分,首先是创建模型,可以手动创建,也可以从meta文件里加载graph进行创建。

    模型加载为:

    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('/xx/model.ckpt.meta')
        saver.restore(sess, "/xx/model.ckpt")
    

    .meta文件中保存了图的结构信息,因此需要在导入checkpoint之前导入它。否则,程序不知道checkpoint中的变量对应的变量。另外也可以:

    # Recreate the EXACT SAME variables
    v1 = tf.Variable(..., name="v1")
    v2 = tf.Variable(..., name="v2")
    ...
    
    # Now load the checkpoint variable values
    with tf.Session() as sess:
        saver = tf.train.Saver()
        saver.restore(sess, "/xx/model.ckpt")
        #saver.restore(sess, tf.train.latest_checkpoint('./'))
    

    PS:不存在model.ckpt文件,saver.py中:Users only need to interact with the user-specified prefix... instead of any physical pathname.

    当然,还有一点需要注意,并非所有的TensorFlow模型都能将graph输出到meta文件中或者从meta文件中加载进来,如果模型有部分不能序列化的部分,则此种方法可能会无效。

    使用Restore的模型

    查看模型的参数

    with tf.Session() as sess:
      saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
      saver.restore(sess, tf.train.latest_checkpoint('./'))
      tvs = [v for v in tf.trainable_variables()]
      for v in tvs:
        print(v.name)
        print(sess.run(v))
    

    如名所言,以上是查看模型中的trainable variables;或者我们也可以查看模型中的所有tensor或者operations,如下:

    with tf.Session() as sess:
      saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
      saver.restore(sess, tf.train.latest_checkpoint('./'))
      gv = [v for v in tf.global_variables()]
      for v in gv:
        print(v.name)
    

    上面通过global_variables()获得的与前trainable_variables类似,只是多了一些非trainable的变量,比如定义时指定为trainable=False的变量,或Optimizer相关的变量。

    下面则可以获得几乎所有的operations相关的tensor:

    with tf.Session() as sess:
      saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
      saver.restore(sess, tf.train.latest_checkpoint('./'))
      ops = [o for o in sess.graph.get_operations()]
      for o in ops:
        print(o.name)
    

    首先,上面的sess.graph.get_operations()可以换为tf.get_default_graph().get_operations(),二者区别无非是graph明确的时候可以直接使用前者,否则需要使用后者。

    此种方法获得的tensor比较齐全,可以从中一窥模型全貌。不过,最方便的方法还是推荐使用tensorboard来查看,当然这需要你提前将sess.graph输出。

    直接使用原始模型进行训练或测试

    这种操作比较简单,无非是找到原始模型的输入、输出即可。
    只要搞清楚输入输出的tensor名字,即可直接使用TensorFlow中graph的get_tensor_by_name函数,建立输入输出的tensor:

    with tf.get_default_graph() as graph:
      data = graph.get_tensor_by_name('data:0')
      output = graph.get_tensor_by_name('output:0')
    

    从模型中找到了输入输出之后,即可直接使用其继续train整个模型,或者将输入数据feed到模型里,并前传得到test输出了。

    需要说明的是,有时候从一个graph里找到输入和输出tensor的名字并不容易,所以,在定义graph时,最好能给相应的tensor取上一个明显的名字,比如:

    data = tf.placeholder(tf.float32, shape=shape, name='input_data')
    preds = tf.nn.softmax(logits, name='output')
    

    诸如此类。这样,就可以直接使用tf.get_tensor_by_name(‘input_data:0’)之类的来找到输入输出了。

    扩展原始模型

    除了直接使用原始模型,还可以在原始模型上进行扩展,比如对1中的output继续进行处理,添加新的操作,可以完成对原始模型的扩展,如:

    with tf.get_default_graph() as graph:
      data = graph.get_tensor_by_name('data:0')
      output = graph.get_tensor_by_name('output:0')
      logits = tf.nn.softmax(output)
    

    使用原始模型的某部分

    有时候,我们有对某模型的一部分进行fine-tune的需求,比如使用一个VGG的前面提取特征的部分,而微调其全连层,或者将其全连层更换为使用convolution来完成,等等。TensorFlow也提供了这种支持,可以使用TensorFlow的stop_gradient函数,将模型的一部分进行冻结。

    with tf.get_default_graph() as graph:
      graph.get_tensor_by_name('fc1:0')
      fc1 = tf.stop_gradient(fc1)
      # add new procedure on fc1
    
  • 相关阅读:
    CSS 选择器之复合选择器
    答辩ppt
    开题报告
    ADS1110/ADS1271
    电感、磁珠和零欧电阻的区别
    ROM、RAM、DRAM、SRAM和FLASH区别
    运放的带宽
    ADC 分辨率和精度的区别
    Verilog
    C语言 文件读取
  • 原文地址:https://www.cnblogs.com/houkai/p/9723988.html
Copyright © 2011-2022 走看看