zoukankan      html  css  js  c++  java
  • tensorflow学习笔记1:导出和加载模型

    用一个非常简单的例子学习导出和加载模型;

    导出

    写一个y=a*x+b的运算,然后保存graph;

    import tensorflow as tf
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    
    with tf.Session() as sess:
        a = tf.Variable(5.0, name='a')
        x = tf.Variable(6.0, name='x')
        b = tf.Variable(3.0, name='b')
        y = tf.add(tf.multiply(a,x),b, name="y")
    
        tf.global_variables_initializer().run()
        
        print (a.eval()) # 5.0
        print (x.eval()) # 6.0
        print (b.eval()) # 3.0
        print (y.eval()) # 33.0
    
        graph = convert_variables_to_constants(sess, sess.graph_def, ["y"])
        #writer = tf.summary.FileWriter("logs/", graph)
        tf.train.write_graph(graph, 'models/', 'test_graph.pb', as_text=False)
    

    运行

    在models目录下生成了test_graph.pb;

    注:convert_variables_to_constants操作是将模型参数froze(保存)进graph中,这时的graph相当于是sess.graph_def + checkpoint,即有模型结构也有模型参数;

    加载

     只加载,获取各个变量的值

    import tensorflow as tf
    from tensorflow.python.platform import gfile
    
    with gfile.FastGFile("models/test_graph.pb", 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        output = tf.import_graph_def(graph_def, return_elements=['a:0', 'x:0', 'b:0','y:0'])
        #print(output)
        
    with tf.Session() as sess:
        result = sess.run(output)
        print (result)
    

      

    运行看以看到原本保存的结果(因为几个变量都已经带入模型,又从模型中加载了出来)

    加载的时候修改变量值

    5*2+3=13,结果正确

    运行时修改变量值

    加载时用一个占位符替掉x常量,在session运行时再给占位符填值;

    5*3+3=18,也正确

    修改计算结果

    偷偷把结果给改了会怎么样?

    呵呵,不知原因为何;以后钻进代码了再说;

    参考:

    https://www.sohu.com/a/233679628_468681

    http://blog.163.com/wujiaxing009@126/blog/static/7198839920174125748893/

  • 相关阅读:
    SDN作业(4)
    SDN作业(3)
    第一次个人编程作业
    SDN作业(2)
    SDN作业(1)
    第一次博客作业
    浅谈闭包
    预编译And作用域链
    定时器
    window事件
  • 原文地址:https://www.cnblogs.com/ZisZ/p/9144859.html
Copyright © 2011-2022 走看看