  • Optimizers

    The five most popular optimizers today are Momentum optimization, Nesterov Accelerated Gradient (NAG), AdaGrad, RMSProp, and Adam. All of them build on vanilla gradient descent, continuously improving it by adding a notion of inertia and some awareness of the surrounding landscape.

    Momentum optimization

    The idea behind momentum optimization is simple: take the inertia of a moving object into account. Imagine a bowling ball rolling down a gentle slope on a smooth surface: it starts out slowly, but quickly picks up momentum until it eventually reaches terminal velocity (assuming there is some friction and air resistance).

    Momentum optimization cares about what the previous gradients were. The formulas:

    (1) \(m \leftarrow \beta m + \eta \nabla_{\theta} J(\theta)\)

    (2) \(\theta \leftarrow \theta - m\)

    The hyperparameter \(\beta\) is called the momentum; it must be set between 0 (high friction) and 1 (no friction). A typical default value is 0.9.

    It is easy to verify that if the gradient stays constant, the terminal velocity (i.e., the maximum size of the weight updates) equals the gradient multiplied by the learning rate and by \(\frac{1}{1-\beta}\): at steady state, \(m = \beta m + \eta \nabla_{\theta} J(\theta)\) gives \(m = \frac{\eta}{1-\beta} \nabla_{\theta} J(\theta)\). With \(\beta = 0.9\), the terminal velocity is 10 times the gradient times the learning rate, so momentum optimization ends up going 10 times faster than plain gradient descent. In deep neural networks that do not use batch normalization, the upper layers often end up having inputs with very different scales, so momentum optimization helps a lot there; it also helps roll past local optima.

    Because of the momentum, the optimizer may overshoot a bit, come back, overshoot again, and oscillate like this several times before stabilizing at the minimum. This is one of the reasons to have a bit of friction in the system: it gets rid of these oscillations and thus speeds up convergence.

    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,momentum=0.9)
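
    To see the two-step update in isolation, here is a minimal NumPy sketch on a toy quadratic loss \(J(\theta) = \theta^2\) (the loss function, learning rate, and iteration count are illustrative assumptions, not from the original text):

    import numpy as np

    beta,eta = 0.9,0.01               # momentum and learning rate
    theta = np.array([4.0])           # toy parameter, minimum is at 0
    m = np.zeros_like(theta)          # momentum vector
    grad = lambda t: 2*t              # gradient of J(theta) = theta**2

    for _ in range(100):
        m = beta*m + eta*grad(theta)  # (1) accumulate momentum
        theta = theta - m             # (2) apply the accumulated update
    print(theta)                      # ends up close to the minimum at 0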

    Nesterov Accelerated Gradient (NAG)

    The formulas:

    (1) \(m \leftarrow \beta m + \eta \nabla_{\theta} J(\theta - \beta m)\)

    (2) \(\theta \leftarrow \theta - m\)

    The only difference from momentum optimization is that the gradient is measured at \(\theta - \beta m\), slightly ahead of \(\theta\) in the direction the momentum is pushing (with this sign convention the update moves \(\theta\) by \(-m\)). This small tweak works because in general the momentum vector points in the right direction, so the gradient measured a bit farther along that direction is slightly more accurate than the gradient at the original position.

    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,momentum=0.9,use_nesterov=True)
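
    The corresponding change to the NumPy sketch above is a single line: the gradient is taken at the look-ahead point \(\theta - \beta m\) (same illustrative toy problem as before):

    import numpy as np

    beta,eta = 0.9,0.01
    theta = np.array([4.0])
    m = np.zeros_like(theta)
    grad = lambda t: 2*t                       # gradient of J(theta) = theta**2

    for _ in range(100):
        m = beta*m + eta*grad(theta - beta*m)  # gradient measured at the look-ahead point
        theta = theta - m
    print(theta)                               # ends up close to the minimum at 0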

    AdaGrad

    AdaGrad usually performs well on simple quadratic problems, but it often stops too early when training neural networks: the learning rate gets scaled down so much that the algorithm grinds to a halt before reaching the global optimum. So even though TensorFlow has an AdagradOptimizer, do not use it to train deep neural networks.

    The formulas:

    (1) \(s \leftarrow s + \nabla_{\theta} J(\theta) \otimes \nabla_{\theta} J(\theta)\)

    (2) \(\theta \leftarrow \theta - \eta \nabla_{\theta} J(\theta) \oslash \sqrt{s+\varepsilon}\)
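
    Here \(\otimes\) and \(\oslash\) denote element-wise multiplication and division. A minimal NumPy sketch of the two steps on the same illustrative toy quadratic as above; notice how the accumulated \(s\) keeps shrinking the effective step, which is exactly the early-stalling problem described above:

    import numpy as np

    eta,eps = 0.1,1e-10
    theta = np.array([4.0])
    s = np.zeros_like(theta)                  # running sum of squared gradients
    grad = lambda t: 2*t

    for _ in range(100):
        g = grad(theta)
        s = s + g*g                           # (1) accumulate squared gradients
        theta = theta - eta*g/np.sqrt(s+eps)  # (2) per-parameter scaled update
    print(theta)                              # still far from 0: the steps have shrunk too much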

    RMSProp

    AdaGrad slows down too fast and never converges to the global optimum. RMSProp fixes this by accumulating only the gradients from the most recent iterations, rather than all gradients since the beginning of training, which it does by using exponential decay in the first step.

    The formulas:

    (1) \(s \leftarrow \beta s + (1-\beta) \nabla_{\theta} J(\theta) \otimes \nabla_{\theta} J(\theta)\)

    (2) \(\theta \leftarrow \theta - \eta \nabla_{\theta} J(\theta) \oslash \sqrt{s+\varepsilon}\)

    The decay rate \(\beta\) is typically set to 0.9.

    optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate,momentum=0.9,decay=0.9,epsilon=1e-10)
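
    The only change from the AdaGrad sketch above is the exponentially decaying accumulator in step (1) (the concrete values are again illustrative):

    import numpy as np

    beta,eta,eps = 0.9,0.1,1e-10
    theta = np.array([4.0])
    s = np.zeros_like(theta)
    grad = lambda t: 2*t

    for _ in range(100):
        g = grad(theta)
        s = beta*s + (1-beta)*g*g             # (1) keep only the most recent squared gradients
        theta = theta - eta*g/np.sqrt(s+eps)  # (2) same scaled update as AdaGrad
    print(theta)                              # close to the minimum at 0, unlike AdaGrad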

    Except on very simple problems, this optimizer almost always performs better than AdaGrad, and it also generally outperforms Momentum optimization and NAG. In fact, it was the optimization algorithm preferred by many researchers until Adam optimization came around.

    Adam optimization

    Adam stands for adaptive moment estimation and combines the ideas of Momentum optimization and RMSProp: like Momentum optimization, it keeps track of an exponentially decaying average of past gradients; and like RMSProp, it keeps track of an exponentially decaying average of past squared gradients.

    The Adam algorithm:

    (1) \(m \leftarrow \beta_1 m + (1-\beta_1) \nabla_{\theta} J(\theta)\)

    (2) \(s \leftarrow \beta_2 s + (1-\beta_2) \nabla_{\theta} J(\theta) \otimes \nabla_{\theta} J(\theta)\)

    (3) \(m \leftarrow \frac{m}{1-\beta_1^T}\)

    (4) \(s \leftarrow \frac{s}{1-\beta_2^T}\)

    (5) \(\theta \leftarrow \theta - \eta \, m \oslash \sqrt{s+\varepsilon}\)

    Note: T is the iteration number (starting at 1). Steps (3) and (4) correct the bias toward zero that m and s have early in training, since both start at 0.

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
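
    A minimal NumPy sketch of the five steps, keeping the bias-corrected values of steps (3) and (4) in separate variables so that the running averages m and s carry over unchanged between iterations (\(\beta_1 = 0.9\) and \(\beta_2 = 0.999\) are the common defaults; the toy problem is illustrative):

    import numpy as np

    beta1,beta2,eta,eps = 0.9,0.999,0.1,1e-8
    theta = np.array([4.0])
    m = np.zeros_like(theta)
    s = np.zeros_like(theta)
    grad = lambda t: 2*t

    for T in range(1,101):                    # T starts at 1
        g = grad(theta)
        m = beta1*m + (1-beta1)*g             # (1) decaying average of gradients
        s = beta2*s + (1-beta2)*g*g           # (2) decaying average of squared gradients
        m_hat = m/(1-beta1**T)                # (3) bias correction
        s_hat = s/(1-beta2**T)                # (4) bias correction
        theta = theta - eta*m_hat/np.sqrt(s_hat+eps)  # (5) update
    print(theta)                              # approaches the minimum at 0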

    Testing the Adam optimizer on MNIST

    import tensorflow as tf
    from tensorflow.contrib.layers import fully_connected,batch_norm
    from tensorflow.examples.tutorials.mnist import input_data
    
    mnist = input_data.read_data_sets('MNIST_data',one_hot=True)
    
    tf.reset_default_graph()
    n_input = 784
    n_hidden1 = 300
    n_hidden2 = 100
    n_output = 10
    
    X = tf.placeholder(tf.float32,shape=(None,n_input),name='X')
    # one-hot labels fed to softmax_cross_entropy_with_logits must be float, not int64
    Y = tf.placeholder(tf.float32,shape=(None,n_output),name='Y')
    # batch normalization parameters
    is_training = tf.placeholder(tf.bool,shape=(),name='is_training')
    bn_params = {'is_training':is_training,'decay':0.99,'updates_collections':None}
    
    with tf.name_scope('dnn'):
        with tf.contrib.framework.arg_scope([fully_connected],normalizer_fn=batch_norm,normalizer_params=bn_params):
            hidden1 = fully_connected(X,n_hidden1,activation_fn=tf.nn.elu,scope='hidden1')
            hidden2 = fully_connected(hidden1,n_hidden2,activation_fn=tf.nn.elu,scope='hidden2')
            # the output layer must produce raw logits (activation_fn=None):
            # softmax_cross_entropy_with_logits applies the softmax itself,
            # so adding tf.nn.softmax here would squash the logits twice
            logits = fully_connected(hidden2,n_output,activation_fn=None,scope='output')
    with tf.name_scope('train'):
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=Y,logits=logits))
        learning_rate = tf.placeholder(tf.float32,shape=(),name='learning_rate')
        optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)
    with tf.name_scope('accuracy'):
        # argmax over logits gives the same prediction as argmax over softmax probabilities
        prab_bool = tf.equal(tf.argmax(logits,1),tf.argmax(Y,1))
        accuracy = tf.reduce_mean(tf.cast(prab_bool,tf.float32))
    with tf.name_scope('tensorboard_mnist'):
        file_writer = tf.summary.FileWriter('./tensorboard/',tf.get_default_graph())
        accuracy_summary = tf.summary.scalar('accuracy',accuracy)
    with tf.name_scope('saver'):
        saver = tf.train.Saver()
    with tf.name_scope('collection'):
        tf.add_to_collection('logits',logits)
        
    epoches = 20
    batch_size = 100
    n_batches = mnist.train.num_examples // batch_size
    rate = 0.1
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(epoches):
            for batch in range(n_batches):
                x_batch,y_batch = mnist.train.next_batch(batch_size)
                sess.run(optimizer,feed_dict={X:x_batch,Y:y_batch,learning_rate:rate,is_training:True})
            result = sess.run([accuracy,accuracy_summary],feed_dict={X:mnist.test.images,Y:mnist.test.labels,
                                                                     learning_rate:rate,is_training:False})
            
            file_writer.add_summary(result[1],epoch)
            print('epoch:{},accuracy:{}'.format(epoch,result[0]))
        saver.save(sess,'./model/model_final.ckpt',global_step=5)
        print('stop')
    
    Extracting MNIST_data\train-images-idx3-ubyte.gz
    Extracting MNIST_data\train-labels-idx1-ubyte.gz
    Extracting MNIST_data\t10k-images-idx3-ubyte.gz
    Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
    epoch:0,accuracy:0.945900022983551
    epoch:1,accuracy:0.9574999809265137
    epoch:2,accuracy:0.9635000228881836
    epoch:3,accuracy:0.9693999886512756
    epoch:4,accuracy:0.970300018787384
    epoch:5,accuracy:0.9704999923706055
    epoch:6,accuracy:0.9758999943733215
    epoch:7,accuracy:0.9757999777793884
    epoch:8,accuracy:0.9768999814987183
    epoch:9,accuracy:0.9783999919891357
    epoch:10,accuracy:0.9783999919891357
    epoch:11,accuracy:0.9642999768257141
    epoch:12,accuracy:0.9779999852180481
    epoch:13,accuracy:0.9799000024795532
    epoch:14,accuracy:0.9760000109672546
    epoch:15,accuracy:0.977400004863739
    epoch:16,accuracy:0.9819999933242798
    epoch:17,accuracy:0.9781000018119812
    epoch:18,accuracy:0.9661999940872192
    epoch:19,accuracy:0.9779000282287598
    stop
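
    Because the logits tensor was stored in a collection and the session was saved with the Saver, the trained model can be restored later without rebuilding the graph by hand. A minimal sketch, assuming the checkpoint files written above (the -5 suffix comes from global_step=5):

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data

    mnist = input_data.read_data_sets('MNIST_data',one_hot=True)

    tf.reset_default_graph()
    # rebuild the graph structure from the .meta file, then load the weights
    saver = tf.train.import_meta_graph('./model/model_final.ckpt-5.meta')
    with tf.Session() as sess:
        saver.restore(sess,'./model/model_final.ckpt-5')
        logits = tf.get_collection('logits')[0]
        graph = tf.get_default_graph()
        X = graph.get_tensor_by_name('X:0')
        is_training = graph.get_tensor_by_name('is_training:0')
        preds = sess.run(tf.argmax(logits,1),
                         feed_dict={X:mnist.test.images[:5],is_training:False})
        print(preds)   # predicted digits for the first five test images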
  • Original post: https://www.cnblogs.com/xiaobingqianrui/p/10756046.html