zoukankan      html  css  js  c++  java
  • 004-2-过拟合,dropout

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    # Load MNIST into ./MNIST_data; one_hot=True yields labels as 10-element
    # one-hot vectors (matching the [None, 10] label placeholder below).
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

    # Mini-batch size and the number of full batches per training epoch.
    batch_size = 100
    n_batch = mnist.train.num_examples // batch_size
    # Number of passes over the training set.
    n_epochs = 31
    # Fraction of hidden units kept during training (0.7 => drop 30%).
    # Evaluation always feeds keep_prob=1.0 so the full network is used.
    train_keep_prob = 0.7

    # Inputs: flattened images (784 features) and one-hot labels (10 classes).
    x = tf.placeholder(tf.float32, [None, 784])
    y = tf.placeholder(tf.float32, [None, 10])
    # Probability that each hidden unit is kept by dropout.
    keep_prob = tf.placeholder(tf.float32)

    # Network: 784 -> 2000 -> 2000 -> 1000 -> 10 with tanh hidden layers and
    # dropout after each hidden layer. Weights are drawn from a truncated
    # normal with stddev 0.1; biases start at 0.1.
    W1 = tf.Variable(tf.truncated_normal([784, 2000], stddev=0.1))
    b1 = tf.Variable(tf.zeros([2000]) + 0.1)
    L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
    L1_drop_out = tf.nn.dropout(L1, keep_prob)

    W2 = tf.Variable(tf.truncated_normal([2000, 2000], stddev=0.1))
    b2 = tf.Variable(tf.zeros([2000]) + 0.1)
    L2 = tf.nn.tanh(tf.matmul(L1_drop_out, W2) + b2)
    L2_drop_out = tf.nn.dropout(L2, keep_prob)

    W3 = tf.Variable(tf.truncated_normal([2000, 1000], stddev=0.1))
    b3 = tf.Variable(tf.zeros([1000]) + 0.1)
    L3 = tf.nn.tanh(tf.matmul(L2_drop_out, W3) + b3)
    L3_drop_out = tf.nn.dropout(L3, keep_prob)

    W4 = tf.Variable(tf.truncated_normal([1000, 10], stddev=0.1))
    b4 = tf.Variable(tf.zeros([10]) + 0.1)
    # Unnormalized class scores. The loss must receive these raw logits.
    logits = tf.matmul(L3_drop_out, W4) + b4
    # Class probabilities, used only for reporting accuracy.
    prediction = tf.nn.softmax(logits)

    # Alternative quadratic cost:
    # loss = tf.reduce_mean(tf.square(y - prediction))
    # Cross-entropy (log-likelihood) cost. softmax_cross_entropy_with_logits_v2
    # applies softmax internally, so it is fed `logits`; feeding `prediction`
    # would apply softmax twice and squash the gradients.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))

    # Plain gradient descent with learning rate 0.2.
    train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

    init = tf.global_variables_initializer()

    # Accuracy: fraction of samples whose highest-scoring class equals the
    # position of the 1 in the one-hot label.
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(n_epochs):
            for _ in range(n_batch):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                sess.run(train_step, feed_dict={x: batch_xs,
                                                y: batch_ys,
                                                keep_prob: train_keep_prob})
            # Evaluate on both splits with dropout disabled, so the train/test
            # gap reflects generalization rather than dropout noise.
            test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images,
                                                     y: mnist.test.labels,
                                                     keep_prob: 1.0})
            train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images,
                                                      y: mnist.train.labels,
                                                      keep_prob: 1.0})
            print("Iter" + str(epoch + 1) + ",Testing accuracy-" + str(test_acc)
                  + ",Train accuracy-" + str(train_acc))
    

      

    keep_prob = 1.0时

    Iter1,Testing accuracy-0.944,Train accuracy-0.958545
    Iter2,Testing accuracy-0.958,Train accuracy-0.974691
    Iter3,Testing accuracy-0.9621,Train accuracy-0.982855
    Iter4,Testing accuracy-0.9652,Train accuracy-0.986455
    Iter5,Testing accuracy-0.968,Train accuracy-0.988364
    Iter6,Testing accuracy-0.9683,Train accuracy-0.989855
    Iter7,Testing accuracy-0.9694,Train accuracy-0.990982
    Iter8,Testing accuracy-0.9687,Train accuracy-0.991636
    Iter9,Testing accuracy-0.9691,Train accuracy-0.992255
    Iter10,Testing accuracy-0.9697,Train accuracy-0.9926
    Iter11,Testing accuracy-0.9697,Train accuracy-0.992909
    Iter12,Testing accuracy-0.9705,Train accuracy-0.993236
    Iter13,Testing accuracy-0.9701,Train accuracy-0.993309
    Iter14,Testing accuracy-0.9707,Train accuracy-0.993527
    Iter15,Testing accuracy-0.9705,Train accuracy-0.993691
    Iter16,Testing accuracy-0.9709,Train accuracy-0.993891
    Iter17,Testing accuracy-0.9707,Train accuracy-0.993982
    Iter18,Testing accuracy-0.9716,Train accuracy-0.994036
    Iter19,Testing accuracy-0.9717,Train accuracy-0.994236
    Iter20,Testing accuracy-0.9722,Train accuracy-0.994364
    Iter21,Testing accuracy-0.9716,Train accuracy-0.994436
    Iter22,Testing accuracy-0.972,Train accuracy-0.994509
    Iter23,Testing accuracy-0.9722,Train accuracy-0.9946
    Iter24,Testing accuracy-0.972,Train accuracy-0.994636
    Iter25,Testing accuracy-0.9723,Train accuracy-0.994709
    Iter26,Testing accuracy-0.9723,Train accuracy-0.994836
    Iter27,Testing accuracy-0.9722,Train accuracy-0.994891
    Iter28,Testing accuracy-0.9727,Train accuracy-0.994964
    Iter29,Testing accuracy-0.9724,Train accuracy-0.995091
    Iter30,Testing accuracy-0.9725,Train accuracy-0.995164
    Iter31,Testing accuracy-0.9725,Train accuracy-0.995182

    keep_prob = 0.7时
    Iter1,Testing accuracy-0.9187,Train accuracy-0.912709
    Iter2,Testing accuracy-0.9281,Train accuracy-0.923782
    Iter3,Testing accuracy-0.9357,Train accuracy-0.935236
    Iter4,Testing accuracy-0.9379,Train accuracy-0.940855
    Iter5,Testing accuracy-0.9441,Train accuracy-0.944564
    Iter6,Testing accuracy-0.9463,Train accuracy-0.948164
    Iter7,Testing accuracy-0.9472,Train accuracy-0.950182
    Iter8,Testing accuracy-0.9515,Train accuracy-0.9544
    Iter9,Testing accuracy-0.9548,Train accuracy-0.956455
    Iter10,Testing accuracy-0.9551,Train accuracy-0.959091
    Iter11,Testing accuracy-0.9566,Train accuracy-0.959891
    Iter12,Testing accuracy-0.9594,Train accuracy-0.962036
    Iter13,Testing accuracy-0.9592,Train accuracy-0.964236
    Iter14,Testing accuracy-0.9585,Train accuracy-0.964818
    Iter15,Testing accuracy-0.9607,Train accuracy-0.966
    Iter16,Testing accuracy-0.961,Train accuracy-0.9668
    Iter17,Testing accuracy-0.9612,Train accuracy-0.967891
    Iter18,Testing accuracy-0.9643,Train accuracy-0.969236
    Iter19,Testing accuracy-0.9646,Train accuracy-0.969945
    Iter20,Testing accuracy-0.9655,Train accuracy-0.970909
    Iter21,Testing accuracy-0.9656,Train accuracy-0.971509
    Iter22,Testing accuracy-0.9668,Train accuracy-0.972891
    Iter23,Testing accuracy-0.9665,Train accuracy-0.972982
    Iter24,Testing accuracy-0.9687,Train accuracy-0.974091
    Iter25,Testing accuracy-0.9673,Train accuracy-0.974782
    Iter26,Testing accuracy-0.9682,Train accuracy-0.975127
    Iter27,Testing accuracy-0.9682,Train accuracy-0.976055
    Iter28,Testing accuracy-0.9703,Train accuracy-0.976582
    Iter29,Testing accuracy-0.9692,Train accuracy-0.976982
    Iter30,Testing accuracy-0.9707,Train accuracy-0.977891
    Iter31,Testing accuracy-0.9703,Train accuracy-0.978091

     

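    可以看到,当 dropout 丢弃 30% 的神经元(keep_prob = 0.7)时,训练集与测试集准确率之间的差距比不使用 dropout(keep_prob = 1.0)时要小:第 31 轮时,前者约为 0.9781 − 0.9703 ≈ 0.8%,后者约为 0.9952 − 0.9725 ≈ 2.3%。这说明 dropout 可以缓解过拟合。不过使用 dropout 时收敛更慢,在 31 轮内其测试准确率(0.9703)仍略低于不使用 dropout 时(0.9725)。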

     

  • 相关阅读:
    CodeForces 681D Gifts by the List (树上DFS)
    UVa 12342 Tax Calculator (水题,纳税)
    CodeForces 681C Heap Operations (模拟题,优先队列)
    CodeForces 682C Alyona and the Tree (树上DFS)
    CodeForces 682B Alyona and Mex (题意水题)
    CodeForces 682A Alyona and Numbers (水题,数学)
    Virtualizing memory type
    页面跳转
    PHP Misc. 函数
    PHP 5 Math 函数
  • 原文地址:https://www.cnblogs.com/Mjerry/p/9828102.html
Copyright © 2011-2022 走看看