  • Learning rate decay strategies in TensorFlow

    As training progresses, we usually want to lower the learning rate so that the model converges more reliably and can settle near the point of lowest loss.

    TensorFlow provides several ways to adjust the learning rate. Searching for "decay" at https://www.tensorflow.org/api_docs/python/tf/compat/v1/train turns up a number of decay strategies:

    • cosine_decay
    • exponential_decay
    • inverse_time_decay
    • linear_cosine_decay
    • natural_exp_decay
    • noisy_linear_cosine_decay
    • polynomial_decay

    This post covers two of these strategies: exponential decay and polynomial decay.

    • Exponential decay

    tf.compat.v1.train.exponential_decay(
        learning_rate,
        global_step,
        decay_steps,
        decay_rate,
        staircase=False,
        name=None
    )
    
    

    learning_rate: the initial learning rate
    global_step: the current training step (how many iterations have been run so far)
    decay_steps: the learning rate is updated every decay_steps steps
    decay_rate: the decay factor used to compute the new learning rate
    staircase: if True, global_step / decay_steps is truncated to an integer, so the learning rate drops in discrete steps; if False, the division stays a float and the decay is smooth

    The learning rate is computed as: decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
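
    To make the formula concrete, here is a pure-Python sketch (an illustration of the formula above, not TensorFlow code) using the same values as the test script below, including the staircase behavior:

    import math

    def exp_decay(lr0, decay_rate, step, decay_steps, staircase=False):
        # decayed_lr = lr0 * decay_rate ^ (step / decay_steps)
        exponent = step / decay_steps
        if staircase:
            exponent = math.floor(exponent)  # decay in discrete jumps
        return lr0 * decay_rate ** exponent

    print(exp_decay(0.5, 0.9, 20, 10))                  # 0.405
    print(exp_decay(0.5, 0.9, 25, 10))                  # ~0.3842 (smooth)
    print(exp_decay(0.5, 0.9, 25, 10, staircase=True))  # 0.405  (stepped)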

    The following test code plots how the learning rate evolves.

    #coding=utf-8
    import matplotlib.pyplot as plt
    import tensorflow as tf

    y = []
    N = 200  # total number of training steps

    with tf.Session() as sess:
        for step in range(N):
            # initial learning rate 0.5, decayed every 10 steps by a factor of 0.9
            learning_rate_decay = tf.train.exponential_decay(
                learning_rate=0.5, global_step=step, decay_steps=10,
                decay_rate=0.9, staircase=False)
            y.append(sess.run(learning_rate_decay))

    x = range(N)
    fig, ax = plt.subplots()
    ax.set_xlabel('step')
    ax.set_ylabel('learning rate')
    ax.plot(x, y, 'r', linewidth=2)
    plt.show()
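
    One caveat about the script above: it builds a new decay op on every loop iteration, which keeps adding nodes to the TF1 graph. A more idiomatic pattern (a sketch under the same hyperparameters) builds the op once and feeds the step in through a placeholder:

    step_ph = tf.placeholder(tf.int32, shape=[])
    lr_op = tf.train.exponential_decay(learning_rate=0.5, global_step=step_ph,
                                       decay_steps=10, decay_rate=0.9)
    with tf.Session() as sess:
        # evaluate the single op at each step instead of rebuilding it
        y = [sess.run(lr_op, feed_dict={step_ph: s}) for s in range(200)]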
    

    The resulting plot shows the learning rate decaying smoothly (exponentially) from 0.5 toward 0 over the 200 steps.

    • Polynomial decay
    tf.compat.v1.train.polynomial_decay(
        learning_rate,
        global_step,
        decay_steps,
        end_learning_rate=0.0001,
        power=1.0,
        cycle=False,
        name=None
    )
    

    You set an initial learning rate and an end learning rate, and the rate decays between them (linearly when power=1.0). cycle controls what happens once the rate reaches end_learning_rate: whether it stays at that minimum, or repeatedly jumps back up and decays again. Since a learning rate that is too small can leave the model stuck near a local optimum, cycling can help avoid that to some extent.
    The computation differs depending on whether cycle is true, as follows:
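
    Per the tf.compat.v1.train.polynomial_decay documentation, the two cases are computed as follows. With cycle=False, global_step is clamped to decay_steps, so the learning rate stays at end_learning_rate once reached:

    global_step = min(global_step, decay_steps)
    decayed_learning_rate = (learning_rate - end_learning_rate) * (1 - global_step / decay_steps) ^ power + end_learning_rate

    With cycle=True, decay_steps is first stretched to the smallest multiple of itself that is >= global_step, which restarts the decay each time the end is reached:

    decay_steps = decay_steps * ceil(global_step / decay_steps)
    decayed_learning_rate = (learning_rate - end_learning_rate) * (1 - global_step / decay_steps) ^ power + end_learning_rate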

    #coding=utf-8
    import matplotlib.pyplot as plt
    import tensorflow as tf

    y = []
    z = []
    N = 200  # total number of training steps

    with tf.Session() as sess:
        for step in range(N):
            # initial learning rate 0.5, decayed toward 0.0001 over 10 steps
            learning_rate_decay = tf.train.polynomial_decay(
                learning_rate=0.5, global_step=step, decay_steps=10,
                end_learning_rate=0.0001, cycle=False)
            y.append(sess.run(learning_rate_decay))

            learning_rate_decay2 = tf.train.polynomial_decay(
                learning_rate=0.5, global_step=step, decay_steps=10,
                end_learning_rate=0.0001, cycle=True)
            z.append(sess.run(learning_rate_decay2))

    x = range(N)
    fig, ax = plt.subplots()
    ax.set_xlabel('step')
    ax.set_ylabel('learning rate')
    ax.plot(x, y, 'r', linewidth=2)  # cycle=False: red
    ax.plot(x, z, 'g', linewidth=2)  # cycle=True: green
    plt.show()
    

    The resulting plot:

    With cycle=False (the red line), the learning rate decays to 0.0001 and stays there. With cycle=True (the green line), once it reaches 0.0001 it jumps back up to a larger value and then decays again, over and over.
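
    To see where the cycle=True jumps come from, here is a pure-Python sketch of the two formulas above (my own illustration, not TensorFlow code):

    import math

    def poly_decay(lr0, step, decay_steps, end_lr=0.0001, power=1.0, cycle=False):
        if cycle:
            # stretch decay_steps to the next multiple >= step
            # (max(1, ...) guards the step == 0 case)
            decay_steps = decay_steps * max(1, math.ceil(step / decay_steps))
        else:
            step = min(step, decay_steps)  # clamp: stay at end_lr afterwards
        return (lr0 - end_lr) * (1 - step / decay_steps) ** power + end_lr

    print(poly_decay(0.5, 15, 10))              # 0.0001: clamped at the floor
    print(poly_decay(0.5, 15, 10, cycle=True))  # ~0.125: decay restarted over 20 steps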

    In real training code, the decay strategy is usually selected via command-line flags, for example:

    def _configure_learning_rate(num_samples_per_epoch, global_step):
      """Configures the learning rate.
    
      Args:
        num_samples_per_epoch: The number of samples in each epoch of training.
        global_step: The global_step tensor.
    
      Returns:
        A `Tensor` representing the learning rate.
    
      Raises:
        ValueError: if FLAGS.learning_rate_decay_type is not recognized.
      """
      # Note: when num_clones is > 1, this will actually have each clone to go
      # over each epoch FLAGS.num_epochs_per_decay times. This is different
      # behavior from sync replicas and is expected to produce different results.
      decay_steps = int(num_samples_per_epoch * FLAGS.num_epochs_per_decay /
                        FLAGS.batch_size)
    
      if FLAGS.sync_replicas:
        decay_steps /= FLAGS.replicas_to_aggregate
    
      if FLAGS.learning_rate_decay_type == 'exponential':
        return tf.train.exponential_decay(FLAGS.learning_rate,
                                          global_step,
                                          decay_steps,
                                          FLAGS.learning_rate_decay_factor,
                                          staircase=True,
                                          name='exponential_decay_learning_rate')
      elif FLAGS.learning_rate_decay_type == 'fixed':
        return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
      elif FLAGS.learning_rate_decay_type == 'polynomial':
        return tf.train.polynomial_decay(FLAGS.learning_rate,
                                         global_step,
                                         decay_steps,
                                         FLAGS.end_learning_rate,
                                         power=1.0,
                                         cycle=False,
                                         name='polynomial_decay_learning_rate')
      else:
        raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                         FLAGS.learning_rate_decay_type)
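
    The returned tensor is then handed to an optimizer. Below is a minimal end-to-end toy sketch (my own example, not from the original post) of wiring a decayed learning rate into an optimizer so that it updates automatically as global_step advances:

    import tensorflow as tf

    x = tf.Variable(5.0)
    loss = tf.square(x)  # toy loss: minimize x^2
    global_step = tf.train.get_or_create_global_step()
    lr = tf.train.exponential_decay(0.5, global_step, decay_steps=10,
                                    decay_rate=0.9, staircase=True)
    # passing global_step makes minimize() increment it on every update
    train_op = tf.train.GradientDescentOptimizer(lr).minimize(
        loss, global_step=global_step)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(30):
            sess.run(train_op)
        print(sess.run(lr))  # 0.5 * 0.9^3 = 0.3645 after 30 steps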
    

    Recommended reading: https://blog.csdn.net/dcrmg/article/details/80017200 describes the various learning rate decay strategies in detail, each with a plot, so you can see at a glance how the learning rate changes under each strategy.
