zoukankan      html  css  js  c++  java
  • TensorFlow——学习率衰减的使用方法

    在TensorFlow的优化器中, 都要设置学习率。学习率是在精度和速度之间找到一个平衡:

    学习率太大,训练的速度会有提升,但是结果的精度不够,而且还可能导致不能收敛出现震荡的情况。

    学习率太小,精度会有所提升,但是训练的速度慢,耗费较多的时间。

    因而我们可以使用退化学习率,又称为衰减学习率。它的作用是在训练的过程中,对学习率的值进行衰减,训练到达一定程度后,使用小的学习率来提高精度。

    在TensorFlow中的方法如下:tf.train.exponential_decay(),该方法的参数如下:

    learning_rate, 初始的学习率的值

    global_step, 迭代步数变量

    decay_steps, 带迭代多少次进行衰减

    decay_rate, 迭代decay_steps次衰减的值

    staircase=False, 默认为False,为True则不衰减

    例如

    tf.train.exponential_decay(initial_learning_rate, global_step=global_step, decay_steps=1000, decay_rate=0.9)表示没经过1000次的迭代,学习率变为原来的0.9。

    增大批次处理样本的数量也可以起到退化学习率的作用。

    下面我们写了一个例子,每迭代10次,则较小为原来的0.5,代码如下:

    import tensorflow as tf
    import numpy as np
    
    global_step = tf.Variable(0, trainable=False)
    initial_learning_rate = 0.1
    
    learning_rate = tf.train.exponential_decay(initial_learning_rate,
                                               global_step=global_step,
                                               decay_steps=10,
                                               decay_rate=0.5)
    
    opt = tf.train.GradientDescentOptimizer(learning_rate)
    add_global = global_step.assign_add(1)
    
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        print(sess.run(learning_rate))
    
        for i in range(50):
            g, rate = sess.run([add_global, learning_rate])
            print(g, rate)

    下面是程序的结果,我们发现没10次就变为原来的一般:

    随后,又在MNIST上面进行了测试,发现使用学习率衰减使得准确率有较好的提升。代码如下:

    import tensorflow as tf
    import numpy as np
    from tensorflow.examples.tutorials.mnist import input_data
    import matplotlib.pyplot as plt
    
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    
    tf.reset_default_graph()
    
    x = tf.placeholder(tf.float32, [None, 784])
    y = tf.placeholder(tf.float32, [None, 10])
    
    w = tf.Variable(tf.random_normal([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    
    pred = tf.matmul(x, w) + b
    pred = tf.nn.softmax(pred)
    
    cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
    
    global_step = tf.Variable(0, trainable=False)
    initial_learning_rate = 0.1
    
    learning_rate = tf.train.exponential_decay(initial_learning_rate,
                                               global_step=global_step,
                                               decay_steps=1000,
                                               decay_rate=0.9)
    
    
    opt = tf.train.GradientDescentOptimizer(learning_rate)
    add_global = global_step.assign_add(1)
    
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
    
    training_epochs = 50
    batch_size = 100
    
    display_step = 1
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
    
        for epoch in range(training_epochs):
            avg_cost = 0
            total_batch = int(mnist.train.num_examples/batch_size)
            for i in range(total_batch):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                _, c, add, rate = sess.run([optimizer, cost, add_global, learning_rate], feed_dict={x:batch_xs, y:batch_ys})
                avg_cost += c / total_batch
    
            if (epoch + 1) % display_step == 0:
                print('epoch= ', epoch+1, ' cost= ', avg_cost, 'add_global=', add, 'rate=', rate)
        print('finished')
    
        correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))

    在使用衰减学习率我们最后的精度达到0.8897,在使用固定的学习率时,精度只有0.8586。

  • 相关阅读:
    点击事件在label标签执行2次解决办法
    关于有偿提供拼图响应式后台的通知.
    jquery鼠标悬停导航下划线滑出效果
    javascript深层解释。
    创业者靠什么成功?
    怎么样才是理想的程序员?
    web总结。
    深入认识盒子模型 display:inline-block BFC 和haslayout !
    25种用户十秒离开你网站的原因!
    JavaScript 引用类型数组!
  • 原文地址:https://www.cnblogs.com/baby-lily/p/10962574.html
Copyright © 2011-2022 走看看