zoukankan      html  css  js  c++  java
  • 8.Dropout

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    # Load MNIST from the local "MNIST_data" directory with one-hot labels.
    mnist = input_data.read_data_sets("MNIST_data",one_hot=True)

    # Size of each mini-batch.
    batch_size = 64
    # Number of full mini-batches per epoch (any remainder is dropped by //).
    n_batch = mnist.train.num_examples // batch_size

    # Three placeholders: flattened 28x28 images, one-hot labels over 10
    # classes, and the dropout keep probability (0.5 while training, 1.0 when
    # evaluating).
    x = tf.placeholder(tf.float32,[None,784])
    y = tf.placeholder(tf.float32,[None,10])
    keep_prob = tf.placeholder(tf.float32)

    # Fully connected network 784-1000-500-10: tanh hidden layers, each
    # followed by dropout controlled by keep_prob.
    W1 = tf.Variable(tf.truncated_normal([784,1000],stddev=0.1))
    b1 = tf.Variable(tf.zeros([1000])+0.1)
    L1 = tf.nn.tanh(tf.matmul(x,W1)+b1)
    L1_drop = tf.nn.dropout(L1,keep_prob)

    W2 = tf.Variable(tf.truncated_normal([1000,500],stddev=0.1))
    b2 = tf.Variable(tf.zeros([500])+0.1)
    L2 = tf.nn.tanh(tf.matmul(L1_drop,W2)+b2)
    L2_drop = tf.nn.dropout(L2,keep_prob)

    W3 = tf.Variable(tf.truncated_normal([500,10],stddev=0.1))
    b3 = tf.Variable(tf.zeros([10])+0.1)
    # Unscaled class scores. The loss below must receive these raw scores;
    # `prediction` (their softmax) is kept for the accuracy computation.
    logits = tf.matmul(L2_drop,W3)+b3
    prediction = tf.nn.softmax(logits)

    # Cross-entropy loss. softmax_cross_entropy applies softmax to its `logits`
    # argument itself; passing the already-softmaxed `prediction` would apply
    # softmax twice and flatten both the loss and its gradients.
    loss = tf.losses.softmax_cross_entropy(onehot_labels=y, logits=logits)
    # Plain gradient descent with learning rate 0.5.
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

    # Op that initializes all variables created above.
    init = tf.global_variables_initializer()

    # Per-example boolean: does the predicted class (argmax over the 10 class
    # scores) match the labelled class?
    correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
    # Accuracy = fraction of True entries in correct_prediction.
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    
    # Train for 31 epochs; after each epoch, report accuracy on the full test
    # set and on the full training set.
    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(31):
            # One pass of n_batch mini-batches with dropout active (keep 50%).
            for _ in range(n_batch):
                images, labels = mnist.train.next_batch(batch_size)
                sess.run(train_step, feed_dict={x: images, y: labels, keep_prob: 0.5})

            # Evaluation runs with dropout disabled (keep_prob=1.0).
            test_feed = {x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0}
            train_feed = {x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0}
            test_acc = sess.run(accuracy, feed_dict=test_feed)
            train_acc = sess.run(accuracy, feed_dict=train_feed)
            print("Iter {},Testing Accuracy {},Training Accuracy {}".format(
                epoch, str(test_acc), str(train_acc)))
    Extracting MNIST_data\train-images-idx3-ubyte.gz
    Extracting MNIST_data\train-labels-idx1-ubyte.gz
    Extracting MNIST_data\t10k-images-idx3-ubyte.gz
    Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
    Iter 0,Testing Accuracy 0.9201,Training Accuracy 0.91234547
    Iter 1,Testing Accuracy 0.9256,Training Accuracy 0.9229636
    Iter 2,Testing Accuracy 0.9359,Training Accuracy 0.9328182
    Iter 3,Testing Accuracy 0.9375,Training Accuracy 0.93716365
    Iter 4,Testing Accuracy 0.9408,Training Accuracy 0.9411273
    Iter 5,Testing Accuracy 0.9407,Training Accuracy 0.94365454
    Iter 6,Testing Accuracy 0.9472,Training Accuracy 0.9484909
    Iter 7,Testing Accuracy 0.9472,Training Accuracy 0.9502
    Iter 8,Testing Accuracy 0.9516,Training Accuracy 0.95336366
    Iter 9,Testing Accuracy 0.9522,Training Accuracy 0.95552725
    Iter 10,Testing Accuracy 0.9525,Training Accuracy 0.95632726
    Iter 11,Testing Accuracy 0.9566,Training Accuracy 0.9578909
    Iter 12,Testing Accuracy 0.9574,Training Accuracy 0.9606182
    Iter 13,Testing Accuracy 0.9573,Training Accuracy 0.96107274
    Iter 14,Testing Accuracy 0.9587,Training Accuracy 0.9614546
    Iter 15,Testing Accuracy 0.9581,Training Accuracy 0.9616727
    Iter 16,Testing Accuracy 0.9599,Training Accuracy 0.96369094
    Iter 17,Testing Accuracy 0.9601,Training Accuracy 0.96403635
    Iter 18,Testing Accuracy 0.9618,Training Accuracy 0.9658909
    Iter 19,Testing Accuracy 0.9608,Training Accuracy 0.9652
    Iter 20,Testing Accuracy 0.9618,Training Accuracy 0.96607274
    Iter 21,Testing Accuracy 0.9634,Training Accuracy 0.96794546
    Iter 22,Testing Accuracy 0.9639,Training Accuracy 0.96836364
    Iter 23,Testing Accuracy 0.964,Training Accuracy 0.96965456
    Iter 24,Testing Accuracy 0.9644,Training Accuracy 0.9693091
    Iter 25,Testing Accuracy 0.9647,Training Accuracy 0.9703818
    Iter 26,Testing Accuracy 0.9639,Training Accuracy 0.9702
    Iter 27,Testing Accuracy 0.9651,Training Accuracy 0.9708909
    Iter 28,Testing Accuracy 0.9666,Training Accuracy 0.9711818
    Iter 29,Testing Accuracy 0.9644,Training Accuracy 0.9710364
    Iter 30,Testing Accuracy 0.9659,Training Accuracy 0.97205454
  • 相关阅读:
    C# 二维数组 排列组合
    highcharts(数据可视化框架),ajax传递数据问题
    EasyPoi导入验证功能
    EasyPoi使用入门
    SSJ(Spring+springMVC+JPA)设置xml文件思路流程
    spring框架设置jdbc
    使用JDBC完成CRUD(增删改查)
    Java的数据类型(常量,变量)
    jdk8的安装与卸载
    Java的第一个你好世界
  • 原文地址:https://www.cnblogs.com/liuwenhua/p/11605482.html
Copyright © 2011-2022 走看看