zoukankan      html  css  js  c++  java
  • 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用激活函数

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    INPUT_NODE = 784     # 输入节点
    OUTPUT_NODE = 10     # 输出节点
    LAYER1_NODE = 500    # 隐藏层数       
                                  
    BATCH_SIZE = 100     # 每次batch打包的样本个数        
    
    # 模型相关的参数
    LEARNING_RATE_BASE = 0.01   
    LEARNING_RATE_DECAY = 0.99    
    REGULARAZTION_RATE = 0.0001   
    TRAINING_STEPS = 5000        
    MOVING_AVERAGE_DECAY = 0.99  
    
    def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):
        # 不使用滑动平均类
        if avg_class == None:
            layer1 = tf.matmul(input_tensor, weights1) + biases1
            return tf.matmul(layer1, weights2) + biases2
        else:
            # 使用滑动平均类
            layer1 = tf.matmul(input_tensor, avg_class.average(weights1)) + avg_class.average(biases1)
            return tf.matmul(layer1, avg_class.average(weights2)) + avg_class.average(biases2)  
        
    def train(mnist):
        x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')
        # 生成隐藏层的参数。
        weights1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
        biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))
        # 生成输出层的参数。
        weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
        biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))
    
        # 计算不含滑动平均类的前向传播结果
        y = inference(x, None, weights1, biases1, weights2, biases2)
        
        # 定义训练轮数及相关的滑动平均类 
        global_step = tf.Variable(0, trainable=False)
        variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
        variables_averages_op = variable_averages.apply(tf.trainable_variables())
        average_y = inference(x, variable_averages, weights1, biases1, weights2, biases2)
        
        # 计算交叉熵及其平均值
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
        cross_entropy_mean = tf.reduce_mean(cross_entropy)
        
        # 损失函数的计算
        regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
        regularaztion = regularizer(weights1) + regularizer(weights2)
        loss = cross_entropy_mean + regularaztion
        
        # 设置指数衰减的学习率。
        learning_rate = tf.train.exponential_decay(
            LEARNING_RATE_BASE,
            global_step,
            mnist.train.num_examples / BATCH_SIZE,
            LEARNING_RATE_DECAY,
            staircase=True)
        
        # 优化损失函数
        train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
        
        # 反向传播更新参数和更新每一个参数的滑动平均值
        with tf.control_dependencies([train_step, variables_averages_op]):
            train_op = tf.no_op(name='train')
    
        # 计算正确率
        correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        
        # 初始化会话,并开始训练过程。
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
            test_feed = {x: mnist.test.images, y_: mnist.test.labels} 
            
            # 循环的训练神经网络。
            for i in range(TRAINING_STEPS):
                if i % 1000 == 0:
                    validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                    print("After %d training step(s), validation accuracy using average model is %g " % (i, validate_acc))
                xs,ys=mnist.train.next_batch(BATCH_SIZE)
                sess.run(train_op,feed_dict={x:xs,y_:ys})
            test_acc=sess.run(accuracy,feed_dict=test_feed)
            print(("After %d training step(s), test accuracy using average model is %g" %(TRAINING_STEPS, test_acc)))
            
    def main(argv=None):
        mnist = input_data.read_data_sets("E:\MNIST_data", one_hot=True)
        train(mnist)
    
    if __name__=='__main__':
        main()

  • 相关阅读:
    The Mac Application Environment 不及格的程序员
    Xcode Plugin: Change Code In Running App Without Restart 不及格的程序员
    The property delegate of CALayer cause Crash. 不及格的程序员
    nil localizedTitle in SKProduct 不及格的程序员
    InApp Purchase 不及格的程序员
    Safari Web Content Guide 不及格的程序员
    在Mac OS X Lion 安装 XCode 3.2 不及格的程序员
    illustrate ARC with graphs 不及格的程序员
    Viewing iPhoneOptimized PNGs 不及格的程序员
    What is the dSYM? 不及格的程序员
  • 原文地址:https://www.cnblogs.com/tszr/p/10875658.html
Copyright © 2011-2022 走看看