zoukankan      html  css  js  c++  java
  • 关于神经网络优化的一些理解

    1、相关概念

    • 激活函数。由于在现实生活中,很多分类问题往往倾向于非线性,因此在处理这类问题时,我们根据传统的线性分类并不能达到一个满意的结果,这时候就需要使用激活函数去线性化,这样,就将问题转换成了非线性划分问题。如下图就是一个非线性问题。

      (1)常见的线性函数

    • 损失函数。刻画预测值与真实值的差异,常见的方法是计算交叉熵(交叉熵是满足概率分布的运算)。

    """ 计算交叉熵--损失函数 """
    # 计算交叉熵作为刻画预测值与真实值之间的差距的损失函数
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    # 计算当前样例的交叉熵平均值
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    
    • 正则化。通过限制权重的大小,防止模型过拟合
      (1)正则化有两种方式,L1和L2
    """ 正则化损失函数 """
    # 计算L2正则化损失函数
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    regularization = regularizer(weights1) + regularizer(weights2)
    loss = cross_entropy_mean + regularization
    
    • 反向传播和梯度下降
    • 学习率。梯度下降的幅度。
    """ 通过设置指数衰减的学习率优化,定义反向传播,进而优化正则化损失"""
    # 设置指数衰减的学习率
     learning_rate = tf.train.exponential_decay(LEARN_RATE_BASE, global_step, mnist.train.num_examples / BATCH_SIZE, LEARN_RATE_DECAY)
    # 定义反向传播,进而优化正则化损失(GradientDescentOptimizer表示反向传播)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    
    • 滑动平均。提高模型的健壮性。滑动平均可以看作是变量的过去一段时间取值的均值,相比对变量直接赋值而言,滑动平均得到的值在图像上更加平缓光滑,抖动性更小,不会因为某次的异常取值而使得滑动平均值波动很大。
    """ 滑动平均优化 """
    # 给定滑动平均衰减率和训练轮数变量,初始化滑动平均类
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    # 在所有神经网络的参数上使用滑动平滑
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    

    2、神经网络的的5大优化

    • 5大优化
      (1)设计损失函数
      (2)正则化
      (3)反向传播和梯度下降
      (4)通过指数衰减设置学习率
      (5)滑动平均
    • 联系图(如下图所示)
      在前向传播中:我们会设计损失函数,如计算交叉熵或平均交叉熵等,为了防止过拟合,采用正则化;
      在反向传播中,我们设置学习率,定义反向传播,并通过反向传播优化损失;为了增强模型的健壮性,我们对参数变量(如权重,偏置项等)采用滑动平均模型。

    3、核心代码(使用《实战Google深度学习框架》代码)

    mnist_inference.py

    import tensorflow as tf
    
    # 定义神经网络相关的参数
    INPUT_NODE = 784
    OUTPUT_NODE = 10
    LAYER1_NODE = 500
    
    
    # 获取变量
    def get_weight_variable(shape, regularizer):
        weights = tf.get_variable("weights", shape, initializer=tf.truncated_normal_initializer(stddev=0.1))
        if regularizer != None:
            # 损失函数正则化
            tf.add_to_collection("losses", regularizer(weights))
        return weights
    
    
    # 定义神经网络的前向传播过程
    def inference(input_tensor, regularizer):
        # 声明第一层神经网络的变量并完成前向传播过程
        with tf.variable_scope("layer1"):
            weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
            biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
            layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
    
        # 声明第二层神经网络的变量并完成前向传播过程
        with tf.variable_scope("layer2"):
            weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
            biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
            layer2 = tf.matmul(layer1, weights) + biases
    
        # 返回前向传播的结果
        return layer2
    

    mnist_train.py

    import os
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    import mnist_inference
    
    # 配置神经网络的参数
    BATCH_SIZE = 100
    LEARN_RATE_BASE = 0.8
    LEARN_RATE_DECAY = 0.99
    REGULARIZATION_RATE = 0.0001
    TRAINING_STEPS = 10000
    MOVING_AVERAGE_DECAY = 0.99
    
    # 模型保存的路径和文件名
    MODEL_SAVE_PATH = "model2/"
    MODEL_NAME = "model.ckpt"
    
    def train(mnist):
        # 定义输入输出placehoder
        x = tf.placeholder(dtype=tf.float32, shape=[None, mnist_inference.INPUT_NODE], name="x-input")
        y_ = tf.placeholder(dtype=tf.float32, shape=[None, mnist_inference.OUTPUT_NODE], name="y-input")
        # 定义存储训练轮数的变量
        global_step = tf.Variable(0, trainable=False)
    
        """ 计算神经网络前向传播,损失函数,并正则化损失函数 """
        regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
        y = mnist_inference.inference(x, regularizer)
        # 计算交叉熵作为刻画预测值与真实值之间的差距的损失函数
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
        # 计算当前batch中所有样例的交叉熵平均值
        cross_entropy_mean = tf.reduce_mean(cross_entropy)
        loss = cross_entropy_mean + tf.add_n(tf.get_collection("losses"))
    
        """ 通过设置指数衰减的学习率优化,定义反向传播,进而优化正则化损失"""
        # 设置指数衰减的学习率
        learning_rate = tf.train.exponential_decay(LEARN_RATE_BASE, global_step, mnist.train.num_examples / BATCH_SIZE, LEARN_RATE_DECAY)
        # 定义反向传播,进而优化正则化损失(GradientDescentOptimizer表示反向传播)
        train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    
        """ 滑动平均优化 """
        # 给定滑动平均衰减率和训练轮数变量,初始化滑动平均类
        variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
        # 在所有神经网络的参数上使用滑动平滑
        variables_averages_op = variable_averages.apply(tf.trainable_variables())
    
        with tf.control_dependencies([train_step, variables_averages_op]):
            train_op = tf.no_op(name="train")
    
        # 初始化TensorFlow持久化类
        saver = tf.train.Saver()
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
    
    
            for i in range(TRAINING_STEPS):
                xs, ys = mnist.train.next_batch(BATCH_SIZE)
                _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
    
                if i% 1000 == 0:
                    print("After %d training step(s), loass on training batch is %g " % (step, loss_value))
                    saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
    
    
    # 主程序入口
    def main(argv=None):
        mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
        train(mnist)
    
    if __name__ == '__main__':
        tf.app.run()
    
  • 相关阅读:
    权限框架之Shiro详解(非原创)
    MySQL数据库基础详解(非原创)
    ssm(Spring、Springmvc、Mybatis)实战之淘淘商城-第十四天(非原创)
    ssm(Spring、Springmvc、Mybatis)实战之淘淘商城-第十三天(非原创)
    nginx配置location与rewrite规则教程
    CentOS7安装MySQL 5.7
    MySQL 5.6 解决InnoDB: Error: Table "mysql"."innodb_table_stats" not found.问题
    公文流转系统(未完成)
    对java异常的总结及java项目中的常用的异常处理情况
    课堂动手动脑验证以及自定义异常类实现对异常处理——java异常类
  • 原文地址:https://www.cnblogs.com/komean/p/10730880.html
Copyright © 2011-2022 走看看