zoukankan      html  css  js  c++  java
  • 转载:tensorflow保存训练后的模型

    训练完一个模型后,为了以后重复使用,通常我们需要对模型的结果进行保存。如果用Tensorflow去实现神经网络,所要保存的就是神经网络中的各项权重值。建议可以使用Saver类保存和加载模型的结果。

    1、使用tf.train.Saver.save()方法保存模型

    tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True, write_state=True)

    • sess: 用于保存变量操作的会话。
    • save_path: String类型,用于指定训练结果的保存路径。
    • global_step: 如果提供的话,这个数字会添加到save_path后面,用于构建checkpoint文件。这个参数有助于我们区分不同训练阶段的结果。

    2、使用tf.train.Saver.restore方法价值模型

    tf.train.Saver.restore(sess, save_path)

    • sess: 用于加载变量操作的会话。
    • save_path: 同保存模型是用到的的save_path参数。

    下面通过一个代码演示这两个函数的使用方法

    import tensorflow as tf
    import numpy as np
    
    x = tf.placeholder(tf.float32, shape=[None, 1])
    y = 4 * x + 4
    
    w = tf.Variable(tf.random_normal([1], -1, 1))
    b = tf.Variable(tf.zeros([1]))
    y_predict = w * x + b
    
    
    loss = tf.reduce_mean(tf.square(y - y_predict))
    optimizer = tf.train.GradientDescentOptimizer(0.5)
    train = optimizer.minimize(loss)
    
    isTrain = False
    train_steps = 100
    checkpoint_steps = 50
    checkpoint_dir = ''
    
    saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
    x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
    
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        if isTrain:
            for i in xrange(train_steps):
                sess.run(train, feed_dict={x: x_data})
                if (i + 1) % checkpoint_steps == 0:
                    saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
        else:
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                pass
            print(sess.run(w))
            print(sess.run(b))
  • 相关阅读:
    09-异常处理-成绩判断异常
    继承与多态———动手动脑
    课下作业04-2String的使用方法
    课下作业04-1字符串加密
    课下作业03-2动手动脑及验证
    课下作业03-1请写一个类,在任何时候都可以向它查询“你已经创建了多少个对象?
    课下作业02-动手动脑
    Myschool试题
    使用ADO.NET
    模糊查询和聚合函数
  • 原文地址:https://www.cnblogs.com/txq157/p/7242385.html
Copyright © 2011-2022 走看看