zoukankan      html  css  js  c++  java
  • 吴裕雄 PYTHON 神经网络——TENSORFLOW 学习率的设置

    import tensorflow as tf
    # Experiment 1: a learning rate that is too large.
    # For y = x^2 the gradient is 2x, so with a rate of 1 each update is
    # x <- x - 1 * 2x = -x: the value just flips sign and never converges.
    TRAINING_STEPS = 10
    LEARNING_RATE = 1
    x = tf.Variable(tf.constant(5, dtype=tf.float32), name="x")
    y = tf.square(x)

    optimizer = tf.train.GradientDescentOptimizer(LEARNING_RATE)
    train_op = optimizer.minimize(y)
    init_op = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init_op)
        for step in range(1, TRAINING_STEPS + 1):
            # Apply the update, then read x in a separate run() call so the
            # printed value is the one after this step's update.
            sess.run(train_op)
            x_value = sess.run(x)
            print("After %s iteration(s): x%s is %f." % (step, step, x_value))

    # Experiment 2: a learning rate that is too small.
    # With a rate of 0.001 each update is x <- x - 0.001 * 2x = 0.998 * x,
    # so x shrinks by only 0.2% per step and approaches 0 very slowly.
    TRAINING_STEPS = 1000
    LEARNING_RATE = 0.001
    x = tf.Variable(tf.constant(5, dtype=tf.float32), name="x")
    y = tf.square(x)

    optimizer = tf.train.GradientDescentOptimizer(LEARNING_RATE)
    train_op = optimizer.minimize(y)
    init_op = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init_op)
        for step in range(1, TRAINING_STEPS + 1):
            sess.run(train_op)
            # Report only on steps 1, 101, 201, ..., 901.
            if step % 100 != 1:
                continue
            x_value = sess.run(x)
            print("After %s iteration(s): x%s is %f." % (step, step, x_value))

    # Experiment 3: exponentially decaying learning rate.
    # The rate starts at 0.1 and is multiplied by 0.96 after every step
    # (decay_steps=1, staircase=True), giving large early updates and
    # progressively smaller ones as x approaches the minimum.
    TRAINING_STEPS = 100
    # The step counter is bookkeeping, not a model parameter: mark it
    # non-trainable so it stays out of the trainable-variables collection
    # that minimize() differentiates against by default.
    global_step = tf.Variable(0, trainable=False)
    LEARNING_RATE = tf.train.exponential_decay(0.1, global_step, 1, 0.96, staircase=True)

    x = tf.Variable(tf.constant(5, dtype=tf.float32), name="x")
    y = tf.square(x)
    # Passing global_step makes the optimizer increment it on every update,
    # which in turn drives the decay schedule above.
    train_op = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(y, global_step=global_step)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(TRAINING_STEPS):
            sess.run(train_op)
            if i % 10 == 0:
                # Both values are read after the update, so the reported rate
                # already reflects the incremented global_step.
                LEARNING_RATE_value = sess.run(LEARNING_RATE)
                x_value = sess.run(x)
                print ("After %s iteration(s): x%s is %f, learning rate is %f."% (i+1, i+1, x_value, LEARNING_RATE_value))

  • 相关阅读:
    matplot 代码实例2
    sklearn 线性模型使用入门
    python 之 决策树分类算法
    Leetcode 之Simplify Path @ python
    协同过滤CF算法之入门
    linux 下 rpc python 实例之使用XML-RPC进行远程文件共享
    Linux rpc 编程最简单实例
    Opencv 入门学习之图片人脸识别
    Django1.7开发博客
    Opencv 入门学习1
  • 原文地址:https://www.cnblogs.com/tszr/p/10874418.html
Copyright © 2011-2022 走看看