zoukankan      html  css  js  c++  java
  • 8.Dropout

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    # Load MNIST from the "MNIST_data" directory, with one-hot encoded labels.
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

    # Training settings.
    batch_size = 64        # examples per mini-batch
    n_epochs = 31          # full passes over the training set
    learning_rate = 0.5    # step size for plain gradient descent
    train_keep_prob = 0.5  # dropout keep probability used while training
    # Number of full mini-batches per epoch (any remainder examples are skipped).
    n_batch = mnist.train.num_examples // batch_size

    # Inputs: 784-dim flattened images, 10-class one-hot labels, and the dropout
    # keep probability (fed train_keep_prob for training, 1.0 for evaluation).
    x = tf.placeholder(tf.float32, [None, 784])
    y = tf.placeholder(tf.float32, [None, 10])
    keep_prob = tf.placeholder(tf.float32)

    # Fully connected network 784-1000-500-10 with tanh activations and dropout
    # applied after each hidden layer.
    W1 = tf.Variable(tf.truncated_normal([784, 1000], stddev=0.1))
    b1 = tf.Variable(tf.zeros([1000]) + 0.1)
    L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
    L1_drop = tf.nn.dropout(L1, keep_prob)

    W2 = tf.Variable(tf.truncated_normal([1000, 500], stddev=0.1))
    b2 = tf.Variable(tf.zeros([500]) + 0.1)
    L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
    L2_drop = tf.nn.dropout(L2, keep_prob)

    W3 = tf.Variable(tf.truncated_normal([500, 10], stddev=0.1))
    b3 = tf.Variable(tf.zeros([10]) + 0.1)
    # Unnormalized class scores. The loss must receive these raw logits: feeding
    # it the softmax output would apply softmax twice and flatten the gradients.
    logits = tf.matmul(L2_drop, W3) + b3
    # Class probabilities, used only for the accuracy metric below.
    prediction = tf.nn.softmax(logits)

    # Softmax cross-entropy; applies softmax to `logits` internally.
    loss = tf.losses.softmax_cross_entropy(y, logits)
    # Minimize the loss with plain (stochastic) gradient descent.
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

    # Op that initializes all global variables.
    init = tf.global_variables_initializer()

    # Per-example boolean: does the highest-scoring class match the label?
    # (argmax returns the index of the largest value along axis 1.)
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
    # Accuracy = fraction of correct predictions.
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(n_epochs):
            for batch in range(n_batch):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: train_keep_prob})

            # Evaluate with dropout disabled (keep_prob = 1.0).
            test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
            train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
            print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc) + ",Training Accuracy " + str(train_acc))
    Extracting MNIST_data\train-images-idx3-ubyte.gz
    Extracting MNIST_data\train-labels-idx1-ubyte.gz
    Extracting MNIST_data\t10k-images-idx3-ubyte.gz
    Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
    Iter 0,Testing Accuracy 0.9201,Training Accuracy 0.91234547
    Iter 1,Testing Accuracy 0.9256,Training Accuracy 0.9229636
    Iter 2,Testing Accuracy 0.9359,Training Accuracy 0.9328182
    Iter 3,Testing Accuracy 0.9375,Training Accuracy 0.93716365
    Iter 4,Testing Accuracy 0.9408,Training Accuracy 0.9411273
    Iter 5,Testing Accuracy 0.9407,Training Accuracy 0.94365454
    Iter 6,Testing Accuracy 0.9472,Training Accuracy 0.9484909
    Iter 7,Testing Accuracy 0.9472,Training Accuracy 0.9502
    Iter 8,Testing Accuracy 0.9516,Training Accuracy 0.95336366
    Iter 9,Testing Accuracy 0.9522,Training Accuracy 0.95552725
    Iter 10,Testing Accuracy 0.9525,Training Accuracy 0.95632726
    Iter 11,Testing Accuracy 0.9566,Training Accuracy 0.9578909
    Iter 12,Testing Accuracy 0.9574,Training Accuracy 0.9606182
    Iter 13,Testing Accuracy 0.9573,Training Accuracy 0.96107274
    Iter 14,Testing Accuracy 0.9587,Training Accuracy 0.9614546
    Iter 15,Testing Accuracy 0.9581,Training Accuracy 0.9616727
    Iter 16,Testing Accuracy 0.9599,Training Accuracy 0.96369094
    Iter 17,Testing Accuracy 0.9601,Training Accuracy 0.96403635
    Iter 18,Testing Accuracy 0.9618,Training Accuracy 0.9658909
    Iter 19,Testing Accuracy 0.9608,Training Accuracy 0.9652
    Iter 20,Testing Accuracy 0.9618,Training Accuracy 0.96607274
    Iter 21,Testing Accuracy 0.9634,Training Accuracy 0.96794546
    Iter 22,Testing Accuracy 0.9639,Training Accuracy 0.96836364
    Iter 23,Testing Accuracy 0.964,Training Accuracy 0.96965456
    Iter 24,Testing Accuracy 0.9644,Training Accuracy 0.9693091
    Iter 25,Testing Accuracy 0.9647,Training Accuracy 0.9703818
    Iter 26,Testing Accuracy 0.9639,Training Accuracy 0.9702
    Iter 27,Testing Accuracy 0.9651,Training Accuracy 0.9708909
    Iter 28,Testing Accuracy 0.9666,Training Accuracy 0.9711818
    Iter 29,Testing Accuracy 0.9644,Training Accuracy 0.9710364
    Iter 30,Testing Accuracy 0.9659,Training Accuracy 0.97205454
  • 相关阅读:
    cocos2d-x 3.0rc开发指南:Windows下Android环境搭建
    CSS 最核心的几个概念
    android CMWAP, CMNET有何差别
    JAVA读、写EXCEL文件
    freemarker报错之三
    “大型票务系统”和“实物电商系统”在支付方面的差别和联系
    C++内存泄露检測原理
    递归算法浅谈
    自写图片遮罩层放大功能jquery插件源代码,photobox.js 1.0版,不兼容IE6
    linux 之 getopt_long()
  • 原文地址:https://www.cnblogs.com/liuwenhua/p/11605482.html
Copyright © 2011-2022 走看看