zoukankan      html  css  js  c++  java
  • 004-2-拟合,drop-out

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    # Train a 4-layer fully connected MNIST classifier with dropout and report
    # test/train accuracy after every epoch.

    # Load the MNIST data set from the "MNIST_data" directory with one-hot labels.
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

    # Number of samples per mini-batch.
    batch_size = 100
    # Number of whole mini-batches in one pass over the training set.
    n_batch = mnist.train.num_examples // batch_size

    # Placeholders: flattened images, one-hot labels, and the dropout keep
    # probability (the fraction of neurons that stay active).
    x = tf.placeholder(tf.float32, [None, 784])
    y = tf.placeholder(tf.float32, [None, 10])
    keep_prob = tf.placeholder(tf.float32)

    # Network: 784 -> 2000 -> 2000 -> 1000 -> 10, tanh activations, with
    # dropout after every hidden layer.
    # Weights come from a truncated normal with standard deviation 0.1
    # (stddev is the standard deviation, not the variance); biases start at 0.1.
    W1 = tf.Variable(tf.truncated_normal([784, 2000], stddev=0.1))
    b1 = tf.Variable(tf.zeros([2000]) + 0.1)
    L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
    L1_drop_out = tf.nn.dropout(L1, keep_prob)

    W2 = tf.Variable(tf.truncated_normal([2000, 2000], stddev=0.1))
    b2 = tf.Variable(tf.zeros([2000]) + 0.1)
    L2 = tf.nn.tanh(tf.matmul(L1_drop_out, W2) + b2)
    L2_drop_out = tf.nn.dropout(L2, keep_prob)

    W3 = tf.Variable(tf.truncated_normal([2000, 1000], stddev=0.1))
    b3 = tf.Variable(tf.zeros([1000]) + 0.1)
    L3 = tf.nn.tanh(tf.matmul(L2_drop_out, W3) + b3)
    L3_drop_out = tf.nn.dropout(L3, keep_prob)

    W4 = tf.Variable(tf.truncated_normal([1000, 10], stddev=0.1))
    b4 = tf.Variable(tf.zeros([10]) + 0.1)
    # Keep the unnormalized scores separate from the softmax output: the loss
    # needs the former, the accuracy computation uses the latter.
    logits = tf.matmul(L3_drop_out, W4) + b4
    prediction = tf.nn.softmax(logits)

    # Cross-entropy (log-likelihood) cost.
    # softmax_cross_entropy_with_logits_v2 applies softmax internally, so it must
    # receive the raw logits; passing `prediction` would apply softmax twice and
    # flatten the gradients.
    # Alternative quadratic cost: tf.reduce_mean(tf.square(y - prediction))
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))

    # Plain gradient descent with learning rate 0.2.
    train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

    # Op that initializes all variables.
    init = tf.global_variables_initializer()

    # Accuracy: a sample is correct when the index of the largest predicted
    # probability equals the index of the 1 in the one-hot label.
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(31):
            for batch in range(n_batch):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                # Train with dropout: keep 70% of the neurons active.
                sess.run(train_step, feed_dict={x: batch_xs,
                                                y: batch_ys,
                                                keep_prob: 0.7})
            # Evaluate with dropout disabled (keep_prob = 1.0) on both splits so
            # the train/test gap shows how much the model overfits.
            test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images,
                                                     y: mnist.test.labels,
                                                     keep_prob: 1.0})
            train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images,
                                                      y: mnist.train.labels,
                                                      keep_prob: 1.0})
            print("Iter"+str(epoch+1)+",Testing accuracy-"+str(test_acc)
                   +",Train accuracy-"+str(train_acc))
    

      

    keep_prob = 1.0时

    Iter1,Testing accuracy-0.944,Train accuracy-0.958545
    Iter2,Testing accuracy-0.958,Train accuracy-0.974691
    Iter3,Testing accuracy-0.9621,Train accuracy-0.982855
    Iter4,Testing accuracy-0.9652,Train accuracy-0.986455
    Iter5,Testing accuracy-0.968,Train accuracy-0.988364
    Iter6,Testing accuracy-0.9683,Train accuracy-0.989855
    Iter7,Testing accuracy-0.9694,Train accuracy-0.990982
    Iter8,Testing accuracy-0.9687,Train accuracy-0.991636
    Iter9,Testing accuracy-0.9691,Train accuracy-0.992255
    Iter10,Testing accuracy-0.9697,Train accuracy-0.9926
    Iter11,Testing accuracy-0.9697,Train accuracy-0.992909
    Iter12,Testing accuracy-0.9705,Train accuracy-0.993236
    Iter13,Testing accuracy-0.9701,Train accuracy-0.993309
    Iter14,Testing accuracy-0.9707,Train accuracy-0.993527
    Iter15,Testing accuracy-0.9705,Train accuracy-0.993691
    Iter16,Testing accuracy-0.9709,Train accuracy-0.993891
    Iter17,Testing accuracy-0.9707,Train accuracy-0.993982
    Iter18,Testing accuracy-0.9716,Train accuracy-0.994036
    Iter19,Testing accuracy-0.9717,Train accuracy-0.994236
    Iter20,Testing accuracy-0.9722,Train accuracy-0.994364
    Iter21,Testing accuracy-0.9716,Train accuracy-0.994436
    Iter22,Testing accuracy-0.972,Train accuracy-0.994509
    Iter23,Testing accuracy-0.9722,Train accuracy-0.9946
    Iter24,Testing accuracy-0.972,Train accuracy-0.994636
    Iter25,Testing accuracy-0.9723,Train accuracy-0.994709
    Iter26,Testing accuracy-0.9723,Train accuracy-0.994836
    Iter27,Testing accuracy-0.9722,Train accuracy-0.994891
    Iter28,Testing accuracy-0.9727,Train accuracy-0.994964
    Iter29,Testing accuracy-0.9724,Train accuracy-0.995091
    Iter30,Testing accuracy-0.9725,Train accuracy-0.995164
    Iter31,Testing accuracy-0.9725,Train accuracy-0.995182

    keep_prob = 0.7时
    Iter1,Testing accuracy-0.9187,Train accuracy-0.912709
    Iter2,Testing accuracy-0.9281,Train accuracy-0.923782
    Iter3,Testing accuracy-0.9357,Train accuracy-0.935236
    Iter4,Testing accuracy-0.9379,Train accuracy-0.940855
    Iter5,Testing accuracy-0.9441,Train accuracy-0.944564
    Iter6,Testing accuracy-0.9463,Train accuracy-0.948164
    Iter7,Testing accuracy-0.9472,Train accuracy-0.950182
    Iter8,Testing accuracy-0.9515,Train accuracy-0.9544
    Iter9,Testing accuracy-0.9548,Train accuracy-0.956455
    Iter10,Testing accuracy-0.9551,Train accuracy-0.959091
    Iter11,Testing accuracy-0.9566,Train accuracy-0.959891
    Iter12,Testing accuracy-0.9594,Train accuracy-0.962036
    Iter13,Testing accuracy-0.9592,Train accuracy-0.964236
    Iter14,Testing accuracy-0.9585,Train accuracy-0.964818
    Iter15,Testing accuracy-0.9607,Train accuracy-0.966
    Iter16,Testing accuracy-0.961,Train accuracy-0.9668
    Iter17,Testing accuracy-0.9612,Train accuracy-0.967891
    Iter18,Testing accuracy-0.9643,Train accuracy-0.969236
    Iter19,Testing accuracy-0.9646,Train accuracy-0.969945
    Iter20,Testing accuracy-0.9655,Train accuracy-0.970909
    Iter21,Testing accuracy-0.9656,Train accuracy-0.971509
    Iter22,Testing accuracy-0.9668,Train accuracy-0.972891
    Iter23,Testing accuracy-0.9665,Train accuracy-0.972982
    Iter24,Testing accuracy-0.9687,Train accuracy-0.974091
    Iter25,Testing accuracy-0.9673,Train accuracy-0.974782
    Iter26,Testing accuracy-0.9682,Train accuracy-0.975127
    Iter27,Testing accuracy-0.9682,Train accuracy-0.976055
    Iter28,Testing accuracy-0.9703,Train accuracy-0.976582
    Iter29,Testing accuracy-0.9692,Train accuracy-0.976982
    Iter30,Testing accuracy-0.9707,Train accuracy-0.977891
    Iter31,Testing accuracy-0.9703,Train accuracy-0.978091

     

    可以看到,当 keep_prob = 0.7(即 drop-out 掉 30% 的神经元)时,训练集与测试集准确率的差距(Iter31 时约 0.8 个百分点)比不使用 drop-out(keep_prob = 1.0,约 2.3 个百分点)时要小,说明 drop-out 可以减轻过拟合。

     

  • 相关阅读:
    python-杂烩
    24 Python 对象进阶
    23 Python 面向对象
    22 Python 模块与包
    21 Python 异常处理
    20 Python 常用模块
    18 Python 模块引入
    2 Python 基本语法
    1 Python 环境搭建
    3 Python os 文件和目录
  • 原文地址:https://www.cnblogs.com/Mjerry/p/9828102.html
Copyright © 2011-2022 走看看