zoukankan      html  css  js  c++  java
  • TensorFlow 的学习效率的指数衰减法

    train_step = tf.train.GradientDescentOptimizer(x)

    在之前的几个例子中都出现了如上代码。

    这个优化算法的参数就是学习效率。那么这个学习效率是什么意思,到底取什么样的值比较好呢?

    之前已经讲过,优化算法会反向修改函数中设置为Variable的变量值,使得误差逐步缩小。而这里的学习效率就是Variable更新变化的幅度。

    如果幅度过大,参数就很可能跨过最优值,最后在最优值的两侧来回浮动。

    如果幅度太小,又会大大的增加学习时间。

    比较理想的做法是,在学习初期,将这个值设的大一些,当逐渐靠近最优解的时候,逐渐缩小学习效率使得获得的值更加靠近最优值。

    TensorFlow就为我们提供了这种方法:指数衰减法

    tf.train.exponential_decay

    它实现的功能类似如下代码

    decayed_learning_rate = learning_rate * decay_rate^(global_step/decay_steps)
    • decayed_learning_rate:  优化后的每一轮的学习效率。
    • learning_rate:               最初设置的学习效率。
    • decay_rate:                  衰减系数。
    • decay_steps:                衰减速度。

    将demo1的代码稍作修改,加入今天我们讲到的函数,并且以图形化的方式输出。

    我们会看到,原来0.1的学习效率,所产生的线与我们正确线在有一段距离的地方开始上下浮动,不再靠近我们正确值的线段。

    而给学习效率加上指数衰减算法后,很快我们的生成的线段就与正确值几乎重合了。

    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    
    x_data = np.random.rand(50).astype(np.float32)
    y_data = x_data * 0.1 + 0.3;
    ###
    Weights = tf.Variable(tf.random_uniform([1],-1.0,1.0))
    biases = tf.Variable(tf.zeros([1]))
    
    y = Weights*x_data + biases
    
    loss=tf.reduce_mean(tf.square(y-y_data))
    
    global_step = tf.Variable(0)
    
    learning_rate = tf.train.exponential_decay(0.1,global_step,100,0.96,staircase=True)
    learning_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step=global_step)
    
    init = tf.global_variables_initializer()
    ###
    
    sess = tf.Session()
    sess.run(init)
    
    fig = plt.figure()
    ax = fig.add_subplot(1,1,1)
    
    plt.ion()
    plt.show()
    
    
    for step in range(300):
        sess.run(learning_step)
        if step % 20 == 0:
            y_value=sess.run(y)
            ax.scatter(x_data,y_data)
            ax.scatter(x_data,y_value)
            plt.pause(1)
  • 相关阅读:
    SpringCloud分布式开发五大神兽
    Spring Cloud 架构 五大神兽的功能
    kafka 基础知识梳理-kafka是一种高吞吐量的分布式发布订阅消息系统
    ETL工具之Kettle的简单使用一(不同数据库之间的数据抽取-转换-加载)
    libjson 编译和使用
    一个用C++写的Json解析与处理库
    DB-library 常用函数
    什么是C++虚函数、虚函数的作用和使用方法
    C++用iconv进行页面字符转换
    QT学习:c++解析html相关
  • 原文地址:https://www.cnblogs.com/guolaomao/p/8004607.html
Copyright © 2011-2022 走看看