zoukankan      html  css  js  c++  java
  • tensorflow tf.train.Supervisor作用

    tf.train.Supervisor可以简化编程,避免显示地实现restore操作.通过一个例子看.

    import tensorflow as tf
    import numpy as np
    import os
    log_path = r"D:Sourcemodellinear"
    log_name = "linear.ckpt"
    # Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
    x_data = np.random.rand(100).astype(np.float32)
    y_data = x_data * 0.1 + 0.3
    
    # Try to find values for W and b that compute y_data = W * x_data + b
    # (We know that W should be 0.1 and b 0.3, but TensorFlow will
    # figure that out for us.)
    W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
    b = tf.Variable(tf.zeros([1]))
    y = W * x_data + b
    
    # Minimize the mean squared errors.
    loss = tf.reduce_mean(tf.square(y - y_data))
    optimizer = tf.train.GradientDescentOptimizer(0.5)
    train = optimizer.minimize(loss)
    
    # Before starting, initialize the variables.  We will 'run' this first.
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    
    # Launch the graph.
    sess = tf.Session()
    sess.run(init)
    
    if len(os.listdir(log_path)) != 0:  # 已经有模型直接读取
        saver.restore(sess, os.path.join(log_path, log_name))
    for step in range(201):
        sess.run(train)
        if step % 20 == 0:
            print(step, sess.run(W), sess.run(b))
    saver.save(sess, os.path.join(log_path, log_name))
    

    这段代码是对tensorflow官网上的demo做一个微小的改动.如果模型已经存在,就先读取模型接着训练.tf.train.Supervisor可以简化这个步骤.看下面的代码.

    import tensorflow as tf
    import numpy as np
    import os
    log_path = r"D:Sourcemodelsupervisor"
    log_name = "linear.ckpt"
    x_data = np.random.rand(100).astype(np.float32)
    y_data = x_data * 0.1 + 0.3
    
    W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
    b = tf.Variable(tf.zeros([1]))
    y = W * x_data + b
    
    loss = tf.reduce_mean(tf.square(y - y_data))
    optimizer = tf.train.GradientDescentOptimizer(0.5)
    train = optimizer.minimize(loss)
    
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    
    sv = tf.train.Supervisor(logdir=log_path, init_op=init)  # logdir用来保存checkpoint和summary
    saver = sv.saver  # 创建saver
    with sv.managed_session() as sess:  # 会自动去logdir中去找checkpoint,如果没有的话,自动执行初始化
        for i in range(201):
            sess.run(train)
            if i % 20 == 0:
                print(i, sess.run(W), sess.run(b))
        saver.save(sess, os.path.join(log_path, log_name))
    

    sv = tf.train.Supervisor(logdir=log_path, init_op=init)会判断模型是否存在.如果存在,会自动读取模型.不用显式地调用restore.

    参考资料

    1. tensorflow官方文档
    2. tensorflow学习笔记(二十二):Supervisor
  • 相关阅读:
    【英语天天读】Places and People
    【OpenCV学习】错误处理机制
    【英语天天读】Heart of a stranger 陌生的心灵
    【英语天天读】第一场雪
    【OpenCV学习】角点检测
    【英语天天读】Life is What We Make It
    【英语天天读】培养自信
    【英语天天读】Perseverance
    【OpenCV学习】cvseqpartition序列分类
    【英语天天读】自然
  • 原文地址:https://www.cnblogs.com/zhouyang209117/p/7088051.html
Copyright © 2011-2022 走看看