zoukankan      html  css  js  c++  java
  • 8.Dropout

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    # Load the MNIST data set from the "MNIST_data" directory with one-hot labels
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

    # Number of examples per training batch
    batch_size = 64
    # Number of whole batches that fit into the training set
    n_batch = mnist.train.num_examples // batch_size

    # Three placeholders: input features (784 per example), one-hot labels
    # (10 classes), and the dropout keep probability fed at run time
    x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
    y = tf.placeholder(dtype=tf.float32, shape=[None, 10])
    keep_prob = tf.placeholder(dtype=tf.float32)
    
    # Network layout 784-1000-500-10: two tanh hidden layers, each followed by
    # dropout controlled by keep_prob, then a 10-way output layer.
    W1 = tf.Variable(tf.truncated_normal([784,1000],stddev=0.1))
    b1 = tf.Variable(tf.zeros([1000])+0.1)
    L1 = tf.nn.tanh(tf.matmul(x,W1)+b1)
    L1_drop = tf.nn.dropout(L1,keep_prob)

    W2 = tf.Variable(tf.truncated_normal([1000,500],stddev=0.1))
    b2 = tf.Variable(tf.zeros([500])+0.1)
    L2 = tf.nn.tanh(tf.matmul(L1_drop,W2)+b2)
    L2_drop = tf.nn.dropout(L2,keep_prob)

    W3 = tf.Variable(tf.truncated_normal([500,10],stddev=0.1))
    b3 = tf.Variable(tf.zeros([10])+0.1)
    # Keep the raw output-layer scores (logits) separate from the softmax
    # probabilities: the loss needs the former, the accuracy op the latter.
    logits = tf.matmul(L2_drop,W3)+b3
    prediction = tf.nn.softmax(logits)

    # Cross-entropy loss. tf.losses.softmax_cross_entropy applies softmax
    # internally, so it must receive the unnormalized logits; feeding it the
    # already-normalized `prediction` would apply softmax twice and flatten
    # the gradients.
    loss = tf.losses.softmax_cross_entropy(onehot_labels=y,logits=logits)
    # Plain gradient descent with learning rate 0.5
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
    
    # Op that initializes every global variable in the graph
    init = tf.global_variables_initializer()

    # argmax along axis 1 gives the index of the largest entry in each row,
    # i.e. the class index for the one-hot label and for the prediction
    true_class = tf.argmax(y, axis=1)
    predicted_class = tf.argmax(prediction, axis=1)
    # Boolean tensor: True where the predicted class matches the label
    correct_prediction = tf.equal(true_class, predicted_class)
    # Accuracy is the mean of the matches after casting booleans to float32
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
    # Evaluation feeds never change between epochs; dropout is disabled
    # (keep_prob 1.0) when measuring accuracy.
    test_feed = {x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0}
    train_feed = {x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0}

    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(31):
            # One pass over the training set in mini-batches, keeping each
            # unit with probability 0.5 in the dropout layers
            for _ in range(n_batch):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                step_feed = {x: batch_xs, y: batch_ys, keep_prob: 0.5}
                sess.run(train_step, feed_dict=step_feed)

            # Report accuracy on the full test and training sets after each epoch
            test_acc = sess.run(accuracy, feed_dict=test_feed)
            train_acc = sess.run(accuracy, feed_dict=train_feed)
            print(
                "Iter " + str(epoch)
                + ",Testing Accuracy " + str(test_acc)
                + ",Training Accuracy " + str(train_acc)
            )
    Extracting MNIST_data\train-images-idx3-ubyte.gz
    Extracting MNIST_data\train-labels-idx1-ubyte.gz
    Extracting MNIST_data\t10k-images-idx3-ubyte.gz
    Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
    Iter 0,Testing Accuracy 0.9201,Training Accuracy 0.91234547
    Iter 1,Testing Accuracy 0.9256,Training Accuracy 0.9229636
    Iter 2,Testing Accuracy 0.9359,Training Accuracy 0.9328182
    Iter 3,Testing Accuracy 0.9375,Training Accuracy 0.93716365
    Iter 4,Testing Accuracy 0.9408,Training Accuracy 0.9411273
    Iter 5,Testing Accuracy 0.9407,Training Accuracy 0.94365454
    Iter 6,Testing Accuracy 0.9472,Training Accuracy 0.9484909
    Iter 7,Testing Accuracy 0.9472,Training Accuracy 0.9502
    Iter 8,Testing Accuracy 0.9516,Training Accuracy 0.95336366
    Iter 9,Testing Accuracy 0.9522,Training Accuracy 0.95552725
    Iter 10,Testing Accuracy 0.9525,Training Accuracy 0.95632726
    Iter 11,Testing Accuracy 0.9566,Training Accuracy 0.9578909
    Iter 12,Testing Accuracy 0.9574,Training Accuracy 0.9606182
    Iter 13,Testing Accuracy 0.9573,Training Accuracy 0.96107274
    Iter 14,Testing Accuracy 0.9587,Training Accuracy 0.9614546
    Iter 15,Testing Accuracy 0.9581,Training Accuracy 0.9616727
    Iter 16,Testing Accuracy 0.9599,Training Accuracy 0.96369094
    Iter 17,Testing Accuracy 0.9601,Training Accuracy 0.96403635
    Iter 18,Testing Accuracy 0.9618,Training Accuracy 0.9658909
    Iter 19,Testing Accuracy 0.9608,Training Accuracy 0.9652
    Iter 20,Testing Accuracy 0.9618,Training Accuracy 0.96607274
    Iter 21,Testing Accuracy 0.9634,Training Accuracy 0.96794546
    Iter 22,Testing Accuracy 0.9639,Training Accuracy 0.96836364
    Iter 23,Testing Accuracy 0.964,Training Accuracy 0.96965456
    Iter 24,Testing Accuracy 0.9644,Training Accuracy 0.9693091
    Iter 25,Testing Accuracy 0.9647,Training Accuracy 0.9703818
    Iter 26,Testing Accuracy 0.9639,Training Accuracy 0.9702
    Iter 27,Testing Accuracy 0.9651,Training Accuracy 0.9708909
    Iter 28,Testing Accuracy 0.9666,Training Accuracy 0.9711818
    Iter 29,Testing Accuracy 0.9644,Training Accuracy 0.9710364
    Iter 30,Testing Accuracy 0.9659,Training Accuracy 0.97205454
  • 相关阅读:
    HDFS snapshot操作实战
    不是技术牛人,如何拿到国内IT巨头的Offer(转载)
    HBase的RowKey设计原则
    hbase shell 基本命令总结
    13_Python数据类型字符串加强_Python编程之路
    监督学习与无监督学习的区别_机器学习
    12_Python的(匿名函数)Lambda表达式_Python编程之路
    Python数据挖掘_Python2模块Spynner的安装(安装失败)
    06_Linux目录文件操作命令3查找命令_我的Linux之路
    python数据挖掘_Json结构分析
  • 原文地址:https://www.cnblogs.com/liuwenhua/p/11605482.html
Copyright © 2011-2022 走看看