Hand-designing a neural network to push MNIST classification accuracy above 98%

    The network has two hidden layers with tanh activations, is trained with the Adam optimizer, and the learning rate is lowered as the epochs progress.
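    The schedule implemented below is a plain exponential decay, lr = 0.001 * 0.95**epoch, reassigned at the start of each epoch, so the first few values are 0.001, 0.00095 and 0.0009025, which is exactly what the training log further down reports. As a small illustrative check (not part of the original script):

    # First few values of the learning-rate schedule used in the training loop below
    for epoch in range(3):
        print(epoch, round(0.001 * 0.95 ** epoch, 7))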

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    # Load the MNIST dataset (downloaded into ./MNIST_data on first run), labels one-hot encoded
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
    
    # Size of each mini-batch
    batch_size = 32
    # Number of batches per epoch
    n_batch = mnist.train.num_examples // batch_size
    
    # Placeholders for the images, labels and dropout keep probability;
    # the learning rate is a variable so it can be reassigned every epoch
    x = tf.placeholder(tf.float32, [None, 784])
    y = tf.placeholder(tf.float32, [None, 10])
    keep_prob = tf.placeholder(tf.float32)
    lr = tf.Variable(0.001, dtype=tf.float32)
    
    # A simple fully connected network: 784 -> 500 -> 300 -> 10, tanh hidden activations
    W1 = tf.Variable(tf.truncated_normal([784, 500], stddev=0.1))
    b1 = tf.Variable(tf.zeros([500]) + 0.1)
    L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
    L1_drop = tf.nn.dropout(L1, keep_prob)

    W2 = tf.Variable(tf.truncated_normal([500, 300], stddev=0.1))
    b2 = tf.Variable(tf.zeros([300]) + 0.1)
    L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
    L2_drop = tf.nn.dropout(L2, keep_prob)
    
    W3 = tf.Variable(tf.truncated_normal([300, 10], stddev=0.1))
    b3 = tf.Variable(tf.zeros([10]) + 0.1)
    logits = tf.matmul(L2_drop, W3) + b3
    prediction = tf.nn.softmax(logits)
    
    # Cross-entropy loss; softmax_cross_entropy_with_logits expects raw logits,
    # not the softmax output, so the unscaled logits are passed in here
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
    # Train with Adam, driven by the decaying learning-rate variable
    train_step = tf.train.AdamOptimizer(lr).minimize(loss)
    
    # Initialize all variables
    init = tf.global_variables_initializer()
    
    # Compare predicted and true classes; argmax returns the index of the largest entry along axis 1
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
    # Accuracy = fraction of correctly classified examples
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(51):
            # Exponential learning-rate decay: lr = 0.001 * 0.95^epoch
            sess.run(tf.assign(lr, 0.001 * (0.95 ** epoch)))
            for batch in range(n_batch):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                # keep_prob is fed as 1.0, so dropout is effectively disabled during training
                sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 1.0})

            learning_rate = sess.run(lr)
            acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
            print("Iter " + str(epoch) + ", Testing Accuracy= " + str(acc) + ", Learning Rate= " + str(learning_rate))


    # Output (test accuracy and learning rate per epoch):
    Iter 0, Testing Accuracy= 0.954, Learning Rate= 0.001
    Iter 1, Testing Accuracy= 0.9624, Learning Rate= 0.00095
    Iter 2, Testing Accuracy= 0.9668, Learning Rate= 0.0009025
    Iter 3, Testing Accuracy= 0.9665, Learning Rate= 0.000857375
    Iter 4, Testing Accuracy= 0.9725, Learning Rate= 0.00081450626
    Iter 5, Testing Accuracy= 0.9738, Learning Rate= 0.0007737809
    Iter 6, Testing Accuracy= 0.9769, Learning Rate= 0.0007350919
    Iter 7, Testing Accuracy= 0.9771, Learning Rate= 0.0006983373
    Iter 8, Testing Accuracy= 0.9777, Learning Rate= 0.0006634204
    Iter 9, Testing Accuracy= 0.9764, Learning Rate= 0.0006302494
    Iter 10, Testing Accuracy= 0.9753, Learning Rate= 0.0005987369
    Iter 11, Testing Accuracy= 0.9779, Learning Rate= 0.0005688001
    Iter 12, Testing Accuracy= 0.9777, Learning Rate= 0.0005403601
    Iter 13, Testing Accuracy= 0.9774, Learning Rate= 0.0005133421
    Iter 14, Testing Accuracy= 0.9772, Learning Rate= 0.000487675
    Iter 15, Testing Accuracy= 0.9803, Learning Rate= 0.00046329122
    Iter 16, Testing Accuracy= 0.9802, Learning Rate= 0.00044012666
    Iter 17, Testing Accuracy= 0.9791, Learning Rate= 0.00041812033
    Iter 18, Testing Accuracy= 0.9806, Learning Rate= 0.00039721432
    Iter 19, Testing Accuracy= 0.9803, Learning Rate= 0.0003773536
    Iter 20, Testing Accuracy= 0.9796, Learning Rate= 0.00035848594
    Iter 21, Testing Accuracy= 0.9803, Learning Rate= 0.00034056162
    Iter 22, Testing Accuracy= 0.9788, Learning Rate= 0.00032353355
    Iter 23, Testing Accuracy= 0.9819, Learning Rate= 0.00030735688
    Iter 24, Testing Accuracy= 0.975, Learning Rate= 0.000291989
    Iter 25, Testing Accuracy= 0.9808, Learning Rate= 0.00027738957
    Iter 26, Testing Accuracy= 0.9814, Learning Rate= 0.0002635201
    Iter 27, Testing Accuracy= 0.9802, Learning Rate= 0.00025034408
    Iter 28, Testing Accuracy= 0.9809, Learning Rate= 0.00023782688
    Iter 29, Testing Accuracy= 0.9811, Learning Rate= 0.00022593554
    Iter 30, Testing Accuracy= 0.9816, Learning Rate= 0.00021463877
    Iter 31, Testing Accuracy= 0.9812, Learning Rate= 0.00020390682
    Iter 32, Testing Accuracy= 0.9815, Learning Rate= 0.00019371149
    Iter 33, Testing Accuracy= 0.9815, Learning Rate= 0.0001840259
    Iter 34, Testing Accuracy= 0.9813, Learning Rate= 0.00017482461
    Iter 35, Testing Accuracy= 0.981, Learning Rate= 0.00016608338
    Iter 36, Testing Accuracy= 0.9806, Learning Rate= 0.00015777921
    Iter 37, Testing Accuracy= 0.9818, Learning Rate= 0.00014989026
    Iter 38, Testing Accuracy= 0.982, Learning Rate= 0.00014239574
    Iter 39, Testing Accuracy= 0.9813, Learning Rate= 0.00013527596
    Iter 40, Testing Accuracy= 0.9818, Learning Rate= 0.00012851215
    Iter 41, Testing Accuracy= 0.9827, Learning Rate= 0.00012208655
    Iter 42, Testing Accuracy= 0.9826, Learning Rate= 0.00011598222
    Iter 43, Testing Accuracy= 0.9814, Learning Rate= 0.00011018311
    Iter 44, Testing Accuracy= 0.9823, Learning Rate= 0.000104673956
    Iter 45, Testing Accuracy= 0.9828, Learning Rate= 9.944026e-05
    Iter 46, Testing Accuracy= 0.9824, Learning Rate= 9.446825e-05
    Iter 47, Testing Accuracy= 0.9824, Learning Rate= 8.974483e-05
    Iter 48, Testing Accuracy= 0.983, Learning Rate= 8.525759e-05
    Iter 49, Testing Accuracy= 0.9827, Learning Rate= 8.099471e-05
    Iter 50, Testing Accuracy= 0.9828, Learning Rate= 7.6944976e-05

    The final test accuracy reached 0.9828.
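    Note that the script relies on the TensorFlow 1.x graph API and on tensorflow.examples.tutorials.mnist, both of which are gone in TensorFlow 2.x. As a rough sketch only (assuming TensorFlow 2.x with Keras is installed; this is not the original code and the exact numbers in the log above will not be reproduced), the same 784-500-300-10 tanh architecture with Adam and the 0.95-per-epoch decay could be written like this:

    # Illustrative TF 2.x / Keras equivalent of the network above (assumption: TensorFlow 2.x)
    import tensorflow as tf

    # Keras ships MNIST directly; flatten the 28x28 images and scale pixels to [0, 1]
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
    x_test = x_test.reshape(-1, 784).astype("float32") / 255.0

    # 784 -> 500 -> 300 -> 10 with tanh hidden activations; dropout is omitted because
    # the original loop fed keep_prob=1.0, which disables it anyway
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(784,)),
        tf.keras.layers.Dense(500, activation="tanh"),
        tf.keras.layers.Dense(300, activation="tanh"),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])

    # Same exponential schedule as above: lr = 0.001 * 0.95**epoch
    lr_schedule = tf.keras.callbacks.LearningRateScheduler(lambda epoch: 0.001 * 0.95 ** epoch)

    model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
                  loss="sparse_categorical_crossentropy",  # labels are integer class ids here
                  metrics=["accuracy"])
    model.fit(x_train, y_train, batch_size=32, epochs=51,
              validation_data=(x_test, y_test), callbacks=[lr_schedule])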

    Life is short, why not use Python?
Original article: https://www.cnblogs.com/yqpy/p/11163147.html