zoukankan      html  css  js  c++  java
  • 吴裕雄 Python 神经网络——TensorFlow 学习率的设置

    import tensorflow as tf
    # Experiment 1: a learning rate that is far too large.
    # For y = x**2 the gradient is 2*x, so one gradient-descent step computes
    #     x <- x - LEARNING_RATE * 2 * x = x * (1 - 2 * LEARNING_RATE).
    # With LEARNING_RATE = 1 this is x <- -x: x flips between 5 and -5 on every
    # step and never approaches the minimum at 0.
    TRAINING_STEPS = 10
    LEARNING_RATE = 1

    # Build this experiment in its own graph so its variable does not leak into
    # the default graph, where later experiments' global_variables_initializer()
    # and minimize() (whose default var_list is every trainable variable in the
    # graph) would pick it up.
    with tf.Graph().as_default():
        x = tf.Variable(tf.constant(5, dtype=tf.float32), name="x")
        y = tf.square(x)

        train_op = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(y)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for i in range(TRAINING_STEPS):
                sess.run(train_op)
                # Fetch x in a separate run, after train_op has completed, so the
                # printed value reflects this step's update.
                x_value = sess.run(x)
                print("After %s iteration(s): x%s is %f." % (i + 1, i + 1, x_value))

    # Experiment 2: a learning rate that is far too small.
    # Each step computes x <- x * (1 - 2 * LEARNING_RATE) = 0.998 * x, so x moves
    # toward the minimum at 0 only very slowly, even after 1000 steps.
    TRAINING_STEPS = 1000
    LEARNING_RATE = 0.001

    # Isolated graph: keeps this experiment's variables separate from any other
    # experiment built in the same process.
    with tf.Graph().as_default():
        x = tf.Variable(tf.constant(5, dtype=tf.float32), name="x")
        y = tf.square(x)

        train_op = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(y)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for i in range(TRAINING_STEPS):
                sess.run(train_op)
                # Report every 100 steps (iterations 1, 101, ..., 901) and also
                # after the final step, which the modulo test alone would skip.
                if i % 100 == 0 or i == TRAINING_STEPS - 1:
                    x_value = sess.run(x)
                    print("After %s iteration(s): x%s is %f." % (i + 1, i + 1, x_value))

    # Experiment 3: an exponentially decaying learning rate.
    # With decay_steps=1 and staircase=True the rate is
    #     0.1 * 0.96 ** global_step,
    # i.e. it starts relatively large and shrinks by 4% after every step.
    TRAINING_STEPS = 100

    # Isolated graph: keeps this experiment's variables separate from any other
    # experiment built in the same process.
    with tf.Graph().as_default():
        # Step counter driving the decay schedule; minimize(global_step=...)
        # increments it after each update. trainable=False keeps this integer
        # counter out of the optimizer's default var_list.
        global_step = tf.Variable(0, trainable=False)
        LEARNING_RATE = tf.train.exponential_decay(
            0.1, global_step, 1, 0.96, staircase=True)

        x = tf.Variable(tf.constant(5, dtype=tf.float32), name="x")
        y = tf.square(x)
        train_op = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(
            y, global_step=global_step)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for i in range(TRAINING_STEPS):
                sess.run(train_op)
                # Report every 10 steps and also after the final step.
                if i % 10 == 0 or i == TRAINING_STEPS - 1:
                    # train_op has already incremented global_step, so this is the
                    # rate the *next* step will use. Both values are read in one
                    # run that contains no update, so they are mutually consistent.
                    LEARNING_RATE_value, x_value = sess.run([LEARNING_RATE, x])
                    print("After %s iteration(s): x%s is %f, learning rate is %f."
                          % (i + 1, i + 1, x_value, LEARNING_RATE_value))

  • 相关阅读:
    20171229
    对象关系型数据库管理系统(PostgreSQL)
    CDN技术之--集群服务与负载均衡
    CDN技术之-介绍
    oracle不同用户间访问表不添加用户名(模式)前缀
    ORA-28000: the account is locked
    CDN技术之--该技术概述
    CDN技术之--内容缓存工作原理
    PL/SQL题型代码示例
    在java中使用solr7.2.0 新旧版本创建SolrClient对比
  • 原文地址:https://www.cnblogs.com/tszr/p/10874418.html
Copyright © 2011-2022 走看看