zoukankan      html  css  js  c++  java
  • 8.Dropout

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    # Load the MNIST dataset into the "MNIST_data" directory; labels are one-hot encoded.
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

    # Mini-batch size.
    batch_size = 64
    # Number of full batches per epoch (a trailing partial batch is not run).
    n_batch = mnist.train.num_examples // batch_size

    # Placeholders: 784-feature input rows, 10-way one-hot labels, and the
    # dropout keep probability (fed 0.5 while training, 1.0 when evaluating).
    x = tf.placeholder(tf.float32, [None, 784])
    y = tf.placeholder(tf.float32, [None, 10])
    keep_prob = tf.placeholder(tf.float32)

    # Network: 784 -> 1000 (tanh) -> 500 (tanh) -> 10, dropout after each hidden layer.
    W1 = tf.Variable(tf.truncated_normal([784, 1000], stddev=0.1))
    b1 = tf.Variable(tf.zeros([1000]) + 0.1)
    L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
    L1_drop = tf.nn.dropout(L1, keep_prob=keep_prob)

    W2 = tf.Variable(tf.truncated_normal([1000, 500], stddev=0.1))
    b2 = tf.Variable(tf.zeros([500]) + 0.1)
    L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
    L2_drop = tf.nn.dropout(L2, keep_prob=keep_prob)

    W3 = tf.Variable(tf.truncated_normal([500, 10], stddev=0.1))
    b3 = tf.Variable(tf.zeros([10]) + 0.1)
    # Unnormalized class scores. The loss below applies softmax internally, so it
    # must receive these logits rather than the softmax output.
    logits = tf.matmul(L2_drop, W3) + b3
    # Class probabilities; argmax(prediction) equals argmax(logits).
    prediction = tf.nn.softmax(logits)

    # Cross-entropy loss computed from logits. Feeding `prediction` here (as the
    # original did) would apply softmax twice and flatten the gradients.
    loss = tf.losses.softmax_cross_entropy(y, logits)
    # Plain gradient descent with learning rate 0.5.
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

    # Op that initializes all variables created above.
    init = tf.global_variables_initializer()

    # Boolean vector: True where the predicted class (index of the largest entry)
    # matches the labelled class.
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
    # Accuracy = mean of the boolean vector cast to float.
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(31):
            for _ in range(n_batch):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.5})

            # Evaluate with dropout disabled (keep every unit).
            test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
            train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
            print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc) + ",Training Accuracy " + str(train_acc))
    Extracting MNIST_data\train-images-idx3-ubyte.gz
    Extracting MNIST_data\train-labels-idx1-ubyte.gz
    Extracting MNIST_data\t10k-images-idx3-ubyte.gz
    Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
    Iter 0,Testing Accuracy 0.9201,Training Accuracy 0.91234547
    Iter 1,Testing Accuracy 0.9256,Training Accuracy 0.9229636
    Iter 2,Testing Accuracy 0.9359,Training Accuracy 0.9328182
    Iter 3,Testing Accuracy 0.9375,Training Accuracy 0.93716365
    Iter 4,Testing Accuracy 0.9408,Training Accuracy 0.9411273
    Iter 5,Testing Accuracy 0.9407,Training Accuracy 0.94365454
    Iter 6,Testing Accuracy 0.9472,Training Accuracy 0.9484909
    Iter 7,Testing Accuracy 0.9472,Training Accuracy 0.9502
    Iter 8,Testing Accuracy 0.9516,Training Accuracy 0.95336366
    Iter 9,Testing Accuracy 0.9522,Training Accuracy 0.95552725
    Iter 10,Testing Accuracy 0.9525,Training Accuracy 0.95632726
    Iter 11,Testing Accuracy 0.9566,Training Accuracy 0.9578909
    Iter 12,Testing Accuracy 0.9574,Training Accuracy 0.9606182
    Iter 13,Testing Accuracy 0.9573,Training Accuracy 0.96107274
    Iter 14,Testing Accuracy 0.9587,Training Accuracy 0.9614546
    Iter 15,Testing Accuracy 0.9581,Training Accuracy 0.9616727
    Iter 16,Testing Accuracy 0.9599,Training Accuracy 0.96369094
    Iter 17,Testing Accuracy 0.9601,Training Accuracy 0.96403635
    Iter 18,Testing Accuracy 0.9618,Training Accuracy 0.9658909
    Iter 19,Testing Accuracy 0.9608,Training Accuracy 0.9652
    Iter 20,Testing Accuracy 0.9618,Training Accuracy 0.96607274
    Iter 21,Testing Accuracy 0.9634,Training Accuracy 0.96794546
    Iter 22,Testing Accuracy 0.9639,Training Accuracy 0.96836364
    Iter 23,Testing Accuracy 0.964,Training Accuracy 0.96965456
    Iter 24,Testing Accuracy 0.9644,Training Accuracy 0.9693091
    Iter 25,Testing Accuracy 0.9647,Training Accuracy 0.9703818
    Iter 26,Testing Accuracy 0.9639,Training Accuracy 0.9702
    Iter 27,Testing Accuracy 0.9651,Training Accuracy 0.9708909
    Iter 28,Testing Accuracy 0.9666,Training Accuracy 0.9711818
    Iter 29,Testing Accuracy 0.9644,Training Accuracy 0.9710364
    Iter 30,Testing Accuracy 0.9659,Training Accuracy 0.97205454
  • 相关阅读:
    设置QtCreator多核编译
    ZeroMQ研究与应用分析及学习资料
    彻底卸载Visual Studio 2013、Visual Studio 2015
    delphi 动态设置和访问cxgrid列的Properties
    delphi 拷贝文件时有进度显示
    Delphi 连接mysql的功能,去除乱码, 需要设置字符集
    cxGrid1 的使用方法
    Django day12 分页器
    Django day11(一) ajax 文件上传 提交json格式数据
    Django day08 多表操作 (五) 常用和非常用用字段
  • 原文地址:https://www.cnblogs.com/liuwenhua/p/11605482.html
Copyright © 2011-2022 走看看