zoukankan      html  css  js  c++  java
  • TensorFlow:tf.train.Saver()模型保存与恢复

    1.保存

    将训练好的模型参数保存起来,以便以后进行验证或测试。tf里面提供模型保存的是tf.train.Saver()模块。

    模型保存,先要创建一个Saver对象:如

    saver=tf.train.Saver()

    在创建这个Saver对象的时候,有一个参数经常会用到,max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,但是这样做除了多占用硬盘,并没有实际多大的用处,因此不推荐,如:

    saver=tf.train.Saver(max_to_keep=0)

    当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即

    saver=tf.train.Saver(max_to_keep=1)

    创建完saver对象后,就可以保存训练好的模型了,如:

    saver.save(sess,‘ckpt/mnist.ckpt',global_step=step)

    第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中。

    saver.save(sess, 'my-model', global_step=0) ==>      filename: 'my-model-0'
    ...
    saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

    2.举例

    import tensorflow as tf
    import numpy as np
    x = tf.placeholder(tf.float32, shape=[None, 1])
    y = 4 * x + 4
    w = tf.Variable(tf.random_normal([1], -1, 1))
    b = tf.Variable(tf.zeros([1]))
    y_predict = w * x + b
    loss = tf.reduce_mean(tf.square(y - y_predict))
    optimizer = tf.train.GradientDescentOptimizer(0.5)
    train = optimizer.minimize(loss)
    isTrain = False
    train_steps = 100
    checkpoint_steps = 50
    checkpoint_dir = ''
    saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
    x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        if isTrain:
            for i in xrange(train_steps):
                sess.run(train, feed_dict={x: x_data})
                if (i + 1) % checkpoint_steps == 0:
                    saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
        else:
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                pass
            print(sess.run(w))
            print(sess.run(b)) 

    3.恢复

    用saver.restore()方法恢复变量:

    saver.restore(sess,'ckpt.model_checkpoint_path')

    sess:表示当前会话,之前保存的结果将被加载入这个会话;

    ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,它会去查看checkpoint文件,看看最新的是谁,叫做什么。

    转载:

    【1】https://www.cnblogs.com/denny402/p/6940134.html

    【2】https://blog.csdn.net/u011500062/article/details/51728830

  • 相关阅读:
    编译mysql4.0时候出现错误提示checking "LinuxThreads"... "Not found"
    rhel5上使用源代码安装mysql4.0.x
    T400|X220打开AHCI的正确步骤
    MySQL主从同步、读写分离配置步骤
    Icc编译MySQL性能调研
    Java操作Hbase进行建表、删表以及对数据进行增删改查,条件查询
    hadoop 异常记录 ERROR: org.apache.hadoop.hbase.MasterNotRunningException: Retried 7 times
    创建代理访问.NET WebServices
    国际:2007年最令人失望的九大新兴技术
    黑客惊天发现:苹果能监视iPhone的一举一动
  • 原文地址:https://www.cnblogs.com/chamie/p/8780508.html
Copyright © 2011-2022 走看看