  • What TensorFlow's tf.train.Supervisor does

    tf.train.Supervisor simplifies training code by removing the need to implement the restore step explicitly. An example makes the difference clear.

    import tensorflow as tf
    import numpy as np
    import os
    log_path = r"D:\Source\model\linear"
    log_name = "linear.ckpt"
    # Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
    x_data = np.random.rand(100).astype(np.float32)
    y_data = x_data * 0.1 + 0.3
    
    # Try to find values for W and b that compute y_data = W * x_data + b
    # (We know that W should be 0.1 and b 0.3, but TensorFlow will
    # figure that out for us.)
    W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
    b = tf.Variable(tf.zeros([1]))
    y = W * x_data + b
    
    # Minimize the mean squared errors.
    loss = tf.reduce_mean(tf.square(y - y_data))
    optimizer = tf.train.GradientDescentOptimizer(0.5)
    train = optimizer.minimize(loss)
    
    # Before starting, initialize the variables.  We will 'run' this first.
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    
    # Launch the graph.
    sess = tf.Session()
    sess.run(init)
    
    if len(os.listdir(log_path)) != 0:  # a model already exists, so restore it
        saver.restore(sess, os.path.join(log_path, log_name))
    for step in range(201):
        sess.run(train)
        if step % 20 == 0:
            print(step, sess.run(W), sess.run(b))
    saver.save(sess, os.path.join(log_path, log_name))
    

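    This code makes one small change to the demo on the TensorFlow website: if a model already exists, it is restored first and training continues from there. tf.train.Supervisor simplifies this step, as the following code shows.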

    import tensorflow as tf
    import numpy as np
    import os
    log_path = r"D:\Source\model\supervisor"
    log_name = "linear.ckpt"
    x_data = np.random.rand(100).astype(np.float32)
    y_data = x_data * 0.1 + 0.3
    
    W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
    b = tf.Variable(tf.zeros([1]))
    y = W * x_data + b
    
    loss = tf.reduce_mean(tf.square(y - y_data))
    optimizer = tf.train.GradientDescentOptimizer(0.5)
    train = optimizer.minimize(loss)
    
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    
    sv = tf.train.Supervisor(logdir=log_path, init_op=init)  # logdir stores checkpoints and summaries
    saver = sv.saver  # use the saver created by the Supervisor
    with sv.managed_session() as sess:  # looks for a checkpoint in logdir; if none is found, runs init_op automatically
        for i in range(201):
            sess.run(train)
            if i % 20 == 0:
                print(i, sess.run(W), sess.run(b))
        saver.save(sess, os.path.join(log_path, log_name))
    

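    sv = tf.train.Supervisor(logdir=log_path, init_op=init) checks whether a model already exists in logdir. If it does, the model is restored automatically when the managed session starts, so there is no need to call restore explicitly.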
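
    The "restore if a checkpoint exists, otherwise initialize" behaviour can be sketched without TensorFlow. The code below is only an analogy and not how Supervisor is implemented: managed_state, save_state, init_params and the state.json file name are hypothetical names made up for this illustration, and the checkpoint is a plain JSON file. It fits the same y = 0.1 * x + 0.3 line with hand-written gradient descent.

```python
import json
import os
import random
from contextlib import contextmanager


@contextmanager
def managed_state(logdir, init_fn, filename="state.json"):
    """Hypothetical helper: yield restored state if a checkpoint exists, else fresh state."""
    os.makedirs(logdir, exist_ok=True)
    path = os.path.join(logdir, filename)
    if os.path.exists(path):
        with open(path) as f:
            state = json.load(f)  # checkpoint found: restore it
    else:
        state = init_fn()  # no checkpoint: run the initializer
    yield state


def save_state(logdir, state, filename="state.json"):
    """Hypothetical helper: write the state to the checkpoint file."""
    with open(os.path.join(logdir, filename), "w") as f:
        json.dump(state, f)


def init_params():
    # Mirrors the initial values used above: random W, zero b.
    return {"W": random.uniform(-1.0, 1.0), "b": 0.0, "step": 0}


def train(logdir, steps=200, lr=0.5):
    # Evenly spaced x values in [0, 1], with y = 0.1 * x + 0.3.
    xs = [i / 99 for i in range(100)]
    ys = [0.1 * x + 0.3 for x in xs]
    with managed_state(logdir, init_params) as state:
        for _ in range(steps):
            # Gradients of the mean squared error with respect to W and b.
            errs = [state["W"] * x + state["b"] - y for x, y in zip(xs, ys)]
            grad_w = 2 * sum(e * x for e, x in zip(errs, xs)) / len(xs)
            grad_b = 2 * sum(errs) / len(xs)
            state["W"] -= lr * grad_w
            state["b"] -= lr * grad_b
            state["step"] += 1
        save_state(logdir, state)
    return state
```

    Calling train twice with the same directory shows the effect: the second call continues from the step count stored by the first instead of starting from random parameters, and the training loop itself contains no restore logic.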

    References

    1. TensorFlow official documentation
    2. TensorFlow study notes (22): Supervisor
  • Original article: https://www.cnblogs.com/zhouyang209117/p/7088051.html