zoukankan      html  css  js  c++  java
  • 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用滑动平均

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    # Network dimensions.
    INPUT_NODE = 784              # number of input nodes
    OUTPUT_NODE = 10              # number of output nodes
    LAYER1_NODE = 500             # number of nodes in the hidden layer

    # Number of training examples packed into each mini-batch.
    BATCH_SIZE = 100

    # Optimisation hyper-parameters.
    LEARNING_RATE_BASE = 0.8      # initial learning rate
    LEARNING_RATE_DECAY = 0.99    # decay rate passed to the exponential schedule
    REGULARAZTION_RATE = 0.0001   # scale handed to the L2 regularizer
    TRAINING_STEPS = 5000         # number of training iterations
    
    def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):
        """Compute the forward pass of a one-hidden-layer network.

        The hidden layer applies ReLU to ``input_tensor @ weights1 + biases1``;
        the output layer is linear and returns the raw logits.

        Args:
            input_tensor: Input fed to the first matrix multiplication.
            avg_class: ``None`` to use the variables as given, or an object
                whose ``average(variable)`` method returns the value to use
                in place of each variable (presumably a moving-average
                helper -- confirm against the caller).
            weights1: Hidden-layer weights.
            biases1: Hidden-layer biases.
            weights2: Output-layer weights.
            biases2: Output-layer biases.

        Returns:
            The output-layer result (no softmax is applied here).
        """
        # Identity check: `== None` can be hijacked by an overloaded __eq__.
        if avg_class is None:
            # Plain forward pass using the variables directly.
            layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
            return tf.matmul(layer1, weights2) + biases2

        # Forward pass using the averaged value of every variable.
        layer1 = tf.nn.relu(
            tf.matmul(input_tensor, avg_class.average(weights1))
            + avg_class.average(biases1))
        return tf.matmul(layer1, avg_class.average(weights2)) + avg_class.average(biases2)
        
    def train(mnist):
        """Build the network, train it with mini-batch SGD and print accuracy.

        The graph uses the plain forward pass (no moving averages), a softmax
        cross-entropy loss plus L2 regularization on both weight matrices, and
        gradient descent with an exponentially decayed learning rate.
        Validation accuracy is printed every 1000 steps and test accuracy once
        after ``TRAINING_STEPS`` updates.

        Args:
            mnist: Dataset object exposing ``train`` (with ``num_examples``
                and ``next_batch``), ``validation`` and ``test`` (each with
                ``images`` and ``labels``). Labels are fed as rows of width
                ``OUTPUT_NODE`` and reduced with ``argmax`` over axis 1.
        """
        x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')
        # Hidden-layer parameters.
        weights1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
        biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))
        # Output-layer parameters.
        weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
        biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))

        # Forward pass without a moving-average class.
        y = inference(x, None, weights1, biases1, weights2, biases2)

        # Step counter incremented by the optimizer; excluded from training.
        global_step = tf.Variable(0, trainable=False)

        # Cross-entropy per example and its mean over the batch.
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
        cross_entropy_mean = tf.reduce_mean(cross_entropy)

        # Total loss = mean cross-entropy + L2 penalty on the weights
        # (biases are not regularized).
        regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
        regularaztion = regularizer(weights1) + regularizer(weights2)
        loss = cross_entropy_mean + regularaztion

        # Exponentially decayed learning rate; the decay period is the number
        # of batches in one pass over the training set.
        learning_rate = tf.train.exponential_decay(
            LEARNING_RATE_BASE,
            global_step,
            mnist.train.num_examples / BATCH_SIZE,
            LEARNING_RATE_DECAY,
            staircase=True)

        # Minimise the loss; this also increments global_step.
        train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

        # Wrap the update in a single named op; running train_op forces
        # train_step to run first.
        with tf.control_dependencies([train_step]):
            train_op = tf.no_op(name='train')

        # Fraction of examples whose predicted class matches the label.
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # Create the session and run the training loop.
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
            test_feed = {x: mnist.test.images, y_: mnist.test.labels}

            for i in range(TRAINING_STEPS):
                # Report validation accuracy before the update at every
                # 1000th step (including step 0, i.e. the untrained model).
                if i % 1000 == 0:
                    validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                    # The message no longer claims an "average model": this
                    # graph does not use moving averages.
                    print("After %d training step(s), validation accuracy is %g " % (i, validate_acc))
                xs, ys = mnist.train.next_batch(BATCH_SIZE)
                sess.run(train_op, feed_dict={x: xs, y_: ys})
            test_acc = sess.run(accuracy, feed_dict=test_feed)
            print("After %d training step(s), test accuracy is %g" % (TRAINING_STEPS, test_acc))
            
    def main(argv=None):
        """Load the MNIST data set with one-hot labels and run training.

        Args:
            argv: Unused; accepted so the function can be invoked with an
                argument list.
        """
        # Backslashes must be escaped: an unescaped trailing backslash would
        # escape the closing quote and leave the string literal unterminated.
        mnist = input_data.read_data_sets("E:\\MNIST_data\\", one_hot=True)
        train(mnist)
    
    # Start training only when this file is executed as a script, not when
    # it is imported as a module.
    if __name__ == "__main__":
        main()

  • 相关阅读:
    iaas,paas,saas理解
    July 06th. 2018, Week 27th. Friday
    July 05th. 2018, Week 27th. Thursday
    July 04th. 2018, Week 27th. Wednesday
    July 03rd. 2018, Week 27th. Tuesday
    July 02nd. 2018, Week 27th. Monday
    July 01st. 2018, Week 27th. Sunday
    June 30th. 2018, Week 26th. Saturday
    June 29th. 2018, Week 26th. Friday
    June 28th. 2018, Week 26th. Thursday
  • 原文地址:https://www.cnblogs.com/tszr/p/10875919.html
Copyright © 2011-2022 走看看