  • TensorFlow models: checkpointing, restoring, fine-tuning, saving, and loading

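    • Checkpoint a model (*.index holds the parameter names, *.meta the model graph, *.data* the parameter values)
    import tensorflow as tf  # all snippets here use the TensorFlow 1.x API

    # MODEL_DIR, MODEL_NAME and the other *_NAME constants are assumed to be defined elsewhere.
    tf.reset_default_graph()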
    
    weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
    biases = tf.Variable(0, name="biases")
    
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    print(sess.run([weights]))
    saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
    
    sess.close()
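To make the file roles above concrete, here is a minimal sketch that groups the files belonging to one checkpoint prefix by suffix. The helper name `group_checkpoint_files`, the `demo` function, and the file names it creates are hypothetical. The files are empty placeholders that only imitate the naming scheme, not real TensorFlow output.

```python
import os
import tempfile


def group_checkpoint_files(model_dir, model_name):
    """Group the files written for one checkpoint prefix by their role."""
    groups = {"index": [], "meta": [], "data": []}
    prefix = model_name + "."
    for file_name in sorted(os.listdir(model_dir)):
        if not file_name.startswith(prefix):
            continue  # e.g. the shared "checkpoint" bookkeeping file
        suffix = file_name[len(prefix):]
        if suffix == "index":
            groups["index"].append(file_name)   # parameter names
        elif suffix == "meta":
            groups["meta"].append(file_name)    # model graph
        elif suffix.startswith("data-"):
            groups["data"].append(file_name)    # parameter values
    return groups


def demo():
    # Empty placeholder files that only imitate the naming scheme.
    with tempfile.TemporaryDirectory() as model_dir:
        for name in ("model.index", "model.meta",
                     "model.data-00000-of-00001", "checkpoint"):
            open(os.path.join(model_dir, name), "w").close()
        return group_checkpoint_files(model_dir, "model")


if __name__ == "__main__":
    print(demo())
```

Running the same helper against a real MODEL_DIR after `saver.save` should show which of the three kinds of file each save produced.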
    • Checkpoint a model (when the same model is saved repeatedly, the model graph can be skipped to save time)
    import time

    tf.reset_default_graph()
    
    weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
    biases = tf.Variable(0, name="biases")
    
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    print(sess.run([weights]))
    # The first save also writes the model graph (*.meta).
    saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
    time.sleep(5)
    # Later saves of the same graph skip the *.meta file.
    saver.save(sess, "%s/%s1" % (MODEL_DIR, MODEL1_NAME), write_meta_graph=False)
    time.sleep(5)
    saver.save(sess, "%s/%s1" % (MODEL_DIR, MODEL2_NAME), write_meta_graph=False)
    
    sess.close()
    • Restore a model (the *.meta file is not needed when the network is built by hand)
    tf.reset_default_graph()
    
    weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
    biases = tf.Variable(0, name="biases")
    
    saver = tf.train.Saver()
    sess = tf.Session()
    saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
    
    print(sess.run([weights]))
    
    sess.close()
    • Restore a model (build the network from the *.meta file)
    tf.reset_default_graph()
    
    saver=tf.train.import_meta_graph("%s/%s.meta" % (MODEL_DIR, MODEL_NAME))
    sess = tf.Session()
    saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
    
    all_op = tf.get_default_graph().get_operations()  # get all ops
    all_var = tf.global_variables()  # get all variables (tf.all_variables() is deprecated)
    
    print(sess.run([tf.get_default_graph().get_tensor_by_name("weights:0")]))
    
    sess.close()
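The snippet above fetches the tensor "weights:0" rather than "weights". In a TensorFlow graph a tensor name has the form "<op name>:<output index>", so "weights:0" is the first output of the op named "weights". A minimal sketch of that convention in plain Python (the helper `split_tensor_name` is hypothetical and not part of TensorFlow):

```python
def split_tensor_name(tensor_name):
    """Split "<op name>:<output index>" into (op_name, output_index)."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not op_name or not index.isdigit():
        raise ValueError("expected '<op name>:<output index>', got %r" % tensor_name)
    return op_name, int(index)


if __name__ == "__main__":
    print(split_tensor_name("weights:0"))
```

This is also why the save step for the binary model passes the op name 'weights' as an output node, while get_tensor_by_name needs "weights:0".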
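    • Restore a model (a model can be saved several times into one folder; the checkpoint file automatically records the names of all saved models and the name of the most recently saved one)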
    tf.reset_default_graph()
    
    weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
    biases = tf.Variable(0, name="biases")
    
    saver = tf.train.Saver()
    sess = tf.Session()
    ckpt = tf.train.get_checkpoint_state(MODEL_DIR)
    saver.restore(sess, ckpt.model_checkpoint_path)
    
    print(sess.run([weights]))
    
    sess.close()
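The `checkpoint` file read by `tf.train.get_checkpoint_state` is a small text file with one `model_checkpoint_path` entry (the latest save) and one `all_model_checkpoint_paths` entry per retained save. The sketch below parses that layout in plain Python purely for illustration. The function name `parse_checkpoint_state` and the `SAMPLE` contents are hypothetical, and real code should keep using `tf.train.get_checkpoint_state`.

```python
import re

# One entry per line, e.g.:  model_checkpoint_path: "model2"
_ENTRY = re.compile(r'^(\w+):\s*"(.*)"\s*$')

# Hypothetical contents of MODEL_DIR/checkpoint after three saves.
SAMPLE = '''model_checkpoint_path: "model2"
all_model_checkpoint_paths: "model"
all_model_checkpoint_paths: "model1"
all_model_checkpoint_paths: "model2"
'''


def parse_checkpoint_state(text):
    """Return (latest_path, all_paths) from the text of a checkpoint file."""
    latest = None
    all_paths = []
    for line in text.splitlines():
        match = _ENTRY.match(line.strip())
        if not match:
            continue  # ignore blank or unrecognized lines
        key, value = match.groups()
        if key == "model_checkpoint_path":
            latest = value
        elif key == "all_model_checkpoint_paths":
            all_paths.append(value)
    return latest, all_paths


if __name__ == "__main__":
    print(parse_checkpoint_state(SAMPLE))
```

In the restore snippet above, `ckpt.model_checkpoint_path` corresponds to the first value returned here.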
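    • Fine-tune a model (restore some of the parameters of a previously trained model, add new parameters, and continue training)
    def get_variables_available_in_checkpoint(variables, checkpoint_path, include_global_step=True):
        """Keep only the variables whose name and shape both match an entry in the checkpoint."""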
        ckpt_reader = tf.train.NewCheckpointReader(checkpoint_path)
        ckpt_vars_to_shape_map = ckpt_reader.get_variable_to_shape_map()
        if not include_global_step:
            ckpt_vars_to_shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)
        vars_in_ckpt = {}
        for variable_name, variable in sorted(variables.items()):
            if variable_name in ckpt_vars_to_shape_map:
                if ckpt_vars_to_shape_map[variable_name] == variable.shape.as_list():
                    vars_in_ckpt[variable_name] = variable
        return vars_in_ckpt
    
    tf.reset_default_graph()
    
    weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
    biases = tf.Variable(0, name="biases")
    other_weights = tf.Variable(tf.zeros([10, 10]))
    
    variables_to_init = tf.global_variables()
    variables_to_init_dict = {var.op.name: var for var in variables_to_init}
    available_var_map = get_variables_available_in_checkpoint(variables_to_init_dict,
        "%s/%s" % (MODEL_DIR, MODEL_NAME), include_global_step=False)
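    # Override the initializers of the matching variables so that they load from the checkpoint.
    tf.train.init_from_checkpoint("%s/%s" % (MODEL_DIR, MODEL_NAME), available_var_map)
    
    sess = tf.Session()
    # Runs the overridden initializers; other_weights keeps its own zeros initializer.
    sess.run(tf.global_variables_initializer())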
    
    print(sess.run([weights]))
    
    sess.close()
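The filtering rule used by get_variables_available_in_checkpoint can be shown without TensorFlow: a variable is restored only if the checkpoint has an entry with the same name and the same shape. The sketch below works on plain dictionaries of name to shape. The function name `select_restorable` and the sample shapes are hypothetical, chosen to mirror the snippet above, where the unnamed `other_weights` variable gets the default name "Variable".

```python
def select_restorable(model_shapes, ckpt_shapes, include_global_step=True):
    """Split model variable names into (restorable, skipped) lists."""
    ckpt_shapes = dict(ckpt_shapes)  # copy so the caller's dict is untouched
    if not include_global_step:
        ckpt_shapes.pop("global_step", None)
    restorable, skipped = [], []
    for name in sorted(model_shapes):
        in_ckpt = name in ckpt_shapes
        if in_ckpt and list(ckpt_shapes[name]) == list(model_shapes[name]):
            restorable.append(name)
        else:
            skipped.append(name)  # missing from the checkpoint or shape differs
    return restorable, skipped


if __name__ == "__main__":
    model = {"weights": [10, 10], "biases": [], "Variable": [10, 10]}
    ckpt = {"weights": [10, 10], "biases": []}
    print(select_restorable(model, ckpt, include_global_step=False))
```

Skipped variables are not an error in this scheme; they simply keep their normal initializers, which is what lets new parameters be added to a restored model.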
    • Save a model (binary format)
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    
    tf.reset_default_graph()
    
    saver=tf.train.import_meta_graph("%s/%s.meta" % (MODEL_DIR, MODEL_NAME))
    sess = tf.Session()
    saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
    
    graph_out = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['weights'])
    with tf.gfile.GFile("%s/%s" % (MODEL_DIR, PB_MODEL_NAME), "wb") as output:
        output.write(graph_out.SerializeToString())
    
    sess.close()
    • Load a model (binary format)
    tf.reset_default_graph()
    
    sess = tf.Session()
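    # tf.gfile.FastGFile is deprecated; tf.gfile.GFile is the replacement.
    with tf.gfile.GFile("%s/%s" % (MODEL_DIR, PB_MODEL_NAME), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import into the default graph, which is also sess.graph here.
    tf.import_graph_def(graph_def, name='')
    # No initializer is needed: the frozen graph contains only constants.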
    
    print(sess.run([tf.get_default_graph().get_tensor_by_name("weights:0")]))
    
    sess.close()

    References:

    https://blog.csdn.net/loveliuzz/article/details/81661875

    https://www.cnblogs.com/bbird/p/9951943.html

    https://blog.csdn.net/gzj_1101/article/details/80299610

  • Original article: https://www.cnblogs.com/jhc888007/p/11620821.html