zoukankan      html  css  js  c++  java
  • 【TF-2-4】Tensorflow-模型和数据的保存和载入

    目录

    1. 基本方法
    2. 不需重新定义网络结构的方法
    3. saved_model方式

    附件一:sklearn上的用法

    一、基本方法

    1.1 保存

    • 定义变量
    • 使用saver.save()方法保存
    import tensorflow as tf
    import numpy as np
    W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w')
    b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b')
    
    init = tf.initialize_all_variables()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init)
        save_path = saver.save(sess,"save/model.ckpt")

    1.2 载入

    • 定义变量
    • 使用saver.restore()方法载入
    import tensorflow as tf
    import numpy as np
    W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w')
    b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b')
    
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess,"save/model.ckpt")

    1.3 说明

    1)创建saver时,可以指定需要存储的tensor,如果没有指定,则全部保存;

    2)默认情况下:saver.save(sess,"save/model.ckpt")产生4个文件:

    checkpoint文件保存最新的模型;

    model.ckpt.data 以字典的形式保存权重偏置项等训练参数

    model.ckpt.index:存储训练好的参数索引

    model.ckpt.meta : 元文件(meta) 中保存了MetaGraphDef 的持久化数据,即模型数据,计算图的网络结构信息,完整的graph、variables、operation、collection。

    3)如何知道tensor的名字,最好是定义tensor的时候就指定名字,如上面代码中的name='w',如果你没有定义name,tensorflow也会设置name,只不过这个name就是根据你的tensor或者操作的性质。所以最好还是自己定义好name。

       

    【说明:这种方法不方便的在于,在使用模型的时候,必须把模型的结构重新定义一遍,然后载入对应名字的变量的值。但是很多时候我们都更希望能够读取一个文件然后就直接使用模型,而不是还要把模型重新定义一遍。所以就需要使用另一种方法。】

    二、不需重新定义网络结构的方法

    tf.train.import_meta_graph

    import_meta_graph(

    meta_graph_or_file,

    clear_devices=False,

    import_scope=None,

    **kwargs

    )

    这个方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。

    比如我们想要保存计算最后预测结果的y,则应该在训练阶段将它添加到collection中。具体代码如下:

    2.1 保存

    和1.1一样,保持不变

    2.2 载入

    import tensorflow as tf
    import numpy as np
    # W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w')
    # b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b')
    
    # saver = tf.train.Saver()
    with tf.Session() as sess:
        new_saver = tf.train.import_meta_graph("save/model.ckpt.meta")
        new_saver.restore(sess, "save/model.ckpt")

    【个人理解:model.ckpt.meta : 保存了TensorFlow计算图的网络结构信息,import_meta_graph("save/model.ckpt.meta")这句拉取了结构,故不用重新定义。】

    三、saved_model方式

    实现了 (y = x + b)当输入一个x 那么输出的结果y就等于输入x加上b。

    3.1 保存

    # Author:yifan
    import os
    import tensorflow as tf # 以下所有代码默认导入
    from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
    # 保存模型路径
    PATH = './models'
    # 创建一个变量
    one = tf.Variable(3.0)
    # 创建一个占位符,在 Tensorflow 中需要定义 placeholder 的 type ,一般为 float32 形式
    num = tf.placeholder(tf.float32,name='input')
    # 创建一个加法步骤,注意这里并没有直接计算
    sum = tf.add(num,one,name='output')
    # 初始化变量,如果定义Variable就必须初始化
    init = tf.global_variables_initializer()
    # 创建会话sess
    with tf.Session() as sess:
    	sess.run(init)
    	print(sess.run(sum, feed_dict={num: 5.0}))
    	# #保存SavedModel模型,使用以下三句
    	builder = tf.saved_model.builder.SavedModelBuilder(PATH)
    	signature = predict_signature_def(inputs={'input':num}, outputs={'output':sum})
    	builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.SERVING],signature_def_map={'predict': signature})
    	builder.save()

    说明:

    1. tf.saved_model.builder.SavedModelBuilder:该方法的参数是传入用于保存模型的目录名,目录不用预先创建
    2. predict_signature_def:将输入节点、输出节点和名字(sig_name)传入,生成一个签名对象。传入的参数为输入和输出以及他们的name。
    3. add_meta_graph_and_variables:将签名加入到模型中
    4. 第一个参数传入的是Session它包含了当前graph(图)和Variables(变量)。
    5. 第二个参数是给当前需要保存的MetaGraph 一个标签,标签名可以自定义,在之后载入模型的时候,需要根据这个标签名去查找对应的MetaGraphDef,找不到就会报如 RuntimeError: MetaGraphDef associated with tags 'foo' could not be found in SavedModel这样的错。

    ---- 标签也可以选用系统定义好的参数,tf.saved_model.tag_constants.SERVING与        tf.saved_model.tag_constants.TRAINING等。

    运行结果:8.0,和保存的模型:

    1. 执行完成后会在当前项目的目录下生成models文件夹,里面包含variables文件夹以及saved_model.pb文件。
    2. variables保存所有变量信息,
    3. saved_model.pb用于保存模型结构等信息,含图形结构。

    注意:当前目录下不可以存在models文件夹,否则会报错。

    3.2 载入

    # Author:yifan
    import tensorflow as tf # 以下所有代码默认导入
    PATH = './models'
    with tf.Session() as sess:
      tf.saved_model.loader.load(sess, ["serve"], PATH) 
    #一种载入变量的方式:
      in_x =tf.saved_model.loader.load(sess, ["serve"], PATH).signature_def['predict'].inputs['input'].name
    #另一种载入变量的方式:
    # in_x = sess.graph.get_tensor_by_name('input:0')     #加载输入变量
      y = sess.graph.get_tensor_by_name('output:0')       #加载输出变量
      scores = sess.run(y, feed_dict={in_x: 3.})
      print(scores)

    说明:

    1. tf.saved_model.loader.load方法加载模型,第二个参数["serve"]为TAG标签与存模型时候指定的字段相同(tf.saved_model.tag_constants.SERVING = "serve",本文中调用了tf的定义),第三个参数为模型路径;
    2. tf.saved_model.loader.load(sess, ["serve"], PATH).signature_def['predict'].inputs['input'].name:用signature_def方法从导入的模型中提取签名。和3)作用是一样的。
    3. sess.graph.get_tensor_by_name:加载输入输出变量,注意这里的变量name都需要加上":0",如"input"变为"input:0"
    4. 最后像之前那样sess.run(),feed喂入数据,这里输入了个3.0。

    结果:6.0

    3.3 查看模型的Signature签名

    传统的导入 需要用get_tensor_by_name , 这样就需要记录tensor的name熟悉,很麻烦。通过signature,我们可以指定变量的别名,方便存取。但如果我们拿到了别人的含有signature一个SavedModel模型而且并不知道"标签"那么怎么调用呢?

    ---Tensorflow官方已经为我们准备好了一个脚本,tensorflow下的saved_model_cli.py文件可以帮到。

    我们可以'WIN+R'输入'cmd'然后回车打开你的CMD,然后指定路径到你的模型目录下,运行:

    saved_model_cli show --dir=./ --all

    打印出的信息中我们就可以看到模型的输入/输出的名称、数据类型、shape以及方法名称。

    附件一:sklearn上的用法

    保存参数:

    from sklearn.externals import joblib

    joblib.dump((centres, des_list,img_features), "imgs_features.pkl", compress=3)

    读取参数:

    centres, des_list, img_features = joblib.load("imgs_features.pkl") #读取保存的特征

    参考文章

    【1】 https://blog.csdn.net/thriving_fcl/article/details/71423039

    【2】 https://blog.csdn.net/liuxiao214/article/details/79048136

    【3】 https://blog.csdn.net/thriving_fcl/article/details/75213361

    【4】 https://blog.csdn.net/weixin_43215867/article/details/85038313

  • 相关阅读:
    并查集图冲突hdu1272
    CentOS 7通过yum安装fcitx五笔输入法
    近期的技术问题让云供应商进行预设加密
    POJ 1166 The Clocks (暴搜)
    windows中的mysql修改管理员密码
    Visio画UML类图、序列图 for Java
    js中的时间与毫秒数互相转换
    java.lang.OutOfMemoryError: unable to create new native thread 居然是MQ问题
    WEB移动应用框架构想(转载)
    Android SDK安装教程
  • 原文地址:https://www.cnblogs.com/yifanrensheng/p/13191300.html
Copyright © 2011-2022 走看看