zoukankan      html  css  js  c++  java
  • TensorFlow 模型保存和导入、加载

    在TensorFlow中,保存模型与加载模型所用到的是tf.train.Saver()这个类。我们一般的想法就是,保存模型之后,在另外的文件中重新将模型导入,我可以利用模型中的operation和variable来测试新的数据。


    什么是TensorFlow中的模型

    首先,我们先来理解一下TensorFlow里面的模型是什么。在保存模型后,一般会出现下面四个文件:

    这里写图片描述

    meta graph:保存了TensorFlow的graph。包括all variables,operations,collections等等。这个文件就是上面的.meta文件。

    checkpoint files:二进制文件,保存了所有weights,biases,gradient and all the other variables的值。也就是上图中的.data-00000-of-00001和.index文件。.data文件包含了所有的训练变量。以前的TensorFlow版本是一个ckpt文件,现在就是这两个文件了。与此同时,Tensorflow还有一个名为checkpoint的文件,只保存最新检查点文件的记录,即最新的保存路径。


    保存一个TensorFlow的模型

    在TensorFlow中,如果想保存一个图(graph)或者所有的参数的值,那么就需要用到tf.train.Saver()这个类。

    import tensorflow as tf
    saver = tf.train.Saver()
    sess = tf.Session()
    saver.save(sess, 'my_test_model')
    

      

    上面这段代码最后一句就是保存模型,第二个参数是一个路径(包含模型的名字)。当然还有其他的形参,我们接下来讲:
    global_step:给一个数字,用于保存文件时tensorflow帮你命名。主要是说明了迭代多次后保存了。
    write_meta_graph:bool型,说明要不要把TensorFlow的图保存下来。
    关于save函数更多的说明请参考:
    https://www.tensorflow.org/api_docs/python/tf/train/Saver#save

    例子:

    import tensorflow as tf
    w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
    w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'my_test_model')
    
    # This will save following files in Tensorflow v >= 0.11
    # my_test_model.data-00000-of-00001
    # my_test_model.index
    # my_test_model.meta
    # checkpoint
    

      


    导入一个训练好的模型

    前门讲了如何保存一个模型,现在要把模型导出来用了。

    训练好的模型,.meta文件中已经保存了整个graph,我们无需重建,只要导入.meta文件即可。

    with tf.Session() as sess:
            new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')#这个函数就是讲graph导出来
    

      

    下面用一个例子来说明一下,直接上完整代码:

    第一个文件,训练模型并保存模型:

    #定义模型
    X = tf.placeholder(tf.float32,shape = [None,x_dim],name = 'X')
    Y = tf.placeholder(tf.float32,shape = [None,1], name = 'Y')
    W = tf.Variable(tf.random_normal([x_dim,1]),name='weight')
    b = tf.Variable(tf.random_normal([1]),name='bias')
    hypothesis = tf.sigmoid(tf.matmul(X,W)+b)
    cost = -tf.reduce_mean(Y*tf.log(hypothesis) + (1-Y)*tf.log(1-hypothesis))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
    train = optimizer.minimize(cost)
    
    #假如想要保存hypothesis和cost,以便在保存模型后,重新导入模型时可以使用。
    tf.add_to_collection('hypothesis',hypothesis)#必须有个名字,即第一个参数
    tf.add_to_collection('cost',cost)
    
    mysaver = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    for epoch in range(50):
        avg_cost, _ = sess.run([cost,train],feed_dict = {X:x_data,Y:y_data})
    
    mysaver.save(sess, '../model/model_LR_test') #保存模型
    

      

    第二个文件,加载模型,并利用训练好的模型预测:

    sess = tf.Session()
    #本来我们需要重新像上一个文件那样重新构建整个graph,但是利用下面这个语句就可以加载整个graph了,方便
    new_saver = tf.train.import_meta_graph('../model/model_LR_test.meta')
    new_saver.restore(sess,'../model/model_LR_test')#加载模型中各种变量的值,注意这里不用文件的后缀
    
    #对应第一个文件的add_to_collection()函数
    hyp = tf.get_collection('hypothesis')[0] #返回值是一个list,我们要的是第一个,这也说明可以有多个变量的名字一样。
    
    graph = tf.get_default_graph() 
    X = graph.get_operation_by_name('X').outputs[0]#为了将placeholder加载出来
    
    pred = sess.run(hyp,feed_dict = {X:x_valid})
    print('auc:',auc(y_valid,pred))

    是这样的,使用TensorFlow构建模型的时候,如果一些operation想要在加载模型时用到。那么需要使用add_to_collection()函数来将operation存起来。然后再加载模型后可以调用。当然tensorflow无论怎样都需要给每个东西一个名字(string型),只有通过名字才可以找到对应的operation。

  • 相关阅读:
    angularjs基础——控制器
    angularjs基础——变量绑定
    mysql 小数处理
    centos无法联网解决方法
    mysql 按 in 顺序排序
    html5 file 自定义文件过滤
    淘宝、天猫装修工具
    MapGis如何实现WebGIS分布式大数据存储的
    CentOS
    PHP与Python哪个做网站产品好?
  • 原文地址:https://www.cnblogs.com/wuzaipei/p/10478830.html
Copyright © 2011-2022 走看看