zoukankan      html  css  js  c++  java
  • 9.正则化

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    # Load the MNIST dataset; labels are one-hot encoded (one_hot=True).
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

    # Mini-batch size.
    batch_size = 64
    # Number of full mini-batches per epoch (any remainder is not visited).
    n_batch = mnist.train.num_examples // batch_size

    # Training hyper-parameters (same values as the original tutorial).
    learning_rate = 0.5   # step size for plain gradient descent
    l2_scale = 0.0005     # weight of the L2 penalty in the total loss
    n_epochs = 31

    # Inputs: 784-dimensional feature vectors and 10-class one-hot labels.
    x = tf.placeholder(tf.float32, [None, 784])
    y = tf.placeholder(tf.float32, [None, 10])
    # Dropout keep probability; feeding 1.0 disables dropout.
    keep_prob = tf.placeholder(tf.float32)

    # Fully connected network 784 -> 1000 -> 500 -> 10 with tanh hidden layers.
    W1 = tf.Variable(tf.truncated_normal([784, 1000], stddev=0.1))
    b1 = tf.Variable(tf.zeros([1000]) + 0.1)
    L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
    L1_drop = tf.nn.dropout(L1, keep_prob)

    W2 = tf.Variable(tf.truncated_normal([1000, 500], stddev=0.1))
    b2 = tf.Variable(tf.zeros([500]) + 0.1)
    L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
    L2_drop = tf.nn.dropout(L2, keep_prob)

    W3 = tf.Variable(tf.truncated_normal([500, 10], stddev=0.1))
    b3 = tf.Variable(tf.zeros([10]) + 0.1)
    # Raw (pre-softmax) class scores; the loss must be computed from these.
    logits = tf.matmul(L2_drop, W3) + b3
    # Class probabilities, used only for evaluation below.
    prediction = tf.nn.softmax(logits)

    # L2 penalty; tf.nn.l2_loss(t) is sum(t ** 2) / 2. Biases are included as in
    # the original tutorial, although they are commonly left unregularized.
    l2_loss = (tf.nn.l2_loss(W1) + tf.nn.l2_loss(b1)
               + tf.nn.l2_loss(W2) + tf.nn.l2_loss(b2)
               + tf.nn.l2_loss(W3) + tf.nn.l2_loss(b3))

    # Cross-entropy plus L2 penalty. softmax_cross_entropy applies softmax
    # internally and therefore expects logits; passing the already-softmaxed
    # `prediction` would apply softmax twice and weaken the gradients.
    loss = tf.losses.softmax_cross_entropy(y, logits) + l2_scale * l2_loss
    # Plain (full-step) gradient descent.
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

    # Op that initializes all variables.
    init = tf.global_variables_initializer()

    # Boolean vector: does the predicted class (argmax index) match the label?
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
    # Accuracy = fraction of correct predictions.
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(n_epochs):
            for _ in range(n_batch):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                # keep_prob is 1.0 during training as well, so dropout is
                # disabled and only the L2 penalty regularizes the model.
                sess.run(train_step,
                         feed_dict={x: batch_xs, y: batch_ys, keep_prob: 1.0})

            # Evaluate on the full test and training sets after each epoch.
            test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
            train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
            print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc) + ",Training Accuracy " + str(train_acc))
    Iter 0,Testing Accuracy 0.9451,Training Accuracy 0.94643635
    Iter 1,Testing Accuracy 0.9529,Training Accuracy 0.9566909
    Iter 2,Testing Accuracy 0.96,Training Accuracy 0.96574545
    Iter 3,Testing Accuracy 0.9608,Training Accuracy 0.9655455
    Iter 4,Testing Accuracy 0.9644,Training Accuracy 0.96776366
    Iter 5,Testing Accuracy 0.9644,Training Accuracy 0.96772724
    Iter 6,Testing Accuracy 0.9612,Training Accuracy 0.9637455
    Iter 7,Testing Accuracy 0.9647,Training Accuracy 0.96952724
    Iter 8,Testing Accuracy 0.9635,Training Accuracy 0.9685091
    Iter 9,Testing Accuracy 0.9655,Training Accuracy 0.97016364
    Iter 10,Testing Accuracy 0.9631,Training Accuracy 0.96703637
    Iter 11,Testing Accuracy 0.9649,Training Accuracy 0.96965456
    Iter 12,Testing Accuracy 0.9673,Training Accuracy 0.9712909
    Iter 13,Testing Accuracy 0.9669,Training Accuracy 0.97174543
    Iter 14,Testing Accuracy 0.9644,Training Accuracy 0.9681818
    Iter 15,Testing Accuracy 0.9657,Training Accuracy 0.9709273
    Iter 16,Testing Accuracy 0.9655,Training Accuracy 0.97154546
    Iter 17,Testing Accuracy 0.966,Training Accuracy 0.9701818
    Iter 18,Testing Accuracy 0.9635,Training Accuracy 0.96852726
    Iter 19,Testing Accuracy 0.9665,Training Accuracy 0.9719818
    Iter 20,Testing Accuracy 0.9679,Training Accuracy 0.9732909
    Iter 21,Testing Accuracy 0.9683,Training Accuracy 0.9747273
    Iter 22,Testing Accuracy 0.9664,Training Accuracy 0.9724
    Iter 23,Testing Accuracy 0.9684,Training Accuracy 0.97367275
    Iter 24,Testing Accuracy 0.9666,Training Accuracy 0.9719091
    Iter 25,Testing Accuracy 0.9655,Training Accuracy 0.97212726
    Iter 26,Testing Accuracy 0.9682,Training Accuracy 0.9728
    Iter 27,Testing Accuracy 0.9676,Training Accuracy 0.97221816
    Iter 28,Testing Accuracy 0.9669,Training Accuracy 0.97238183
    Iter 29,Testing Accuracy 0.9675,Training Accuracy 0.97327274
    Iter 30,Testing Accuracy 0.9665,Training Accuracy 0.9725091
     
  • 相关阅读:
    用TortoiseSVN忽略文件或文件夹(ignore)(网络摘抄记录)
    GridView解决同一行item的高度不一样,如何同一行统一高度问题?
    解决android studio引用远程仓库下载慢(转)
    Databinding在自定义ViewGroup中如何绑定view
    (转 )【Android那些高逼格的写法】InvocationHandler与代理模式
    (转)秒懂,Java 注解 (Annotation)你可以这样学
    View拖拽 自定义绑定view拖拽的工具类
    bcrypt对密码加密的一些认识(学习笔记)
    Node.js+Koa开发微信公众号个人笔记(三)响应文本
    Node.js+Koa开发微信公众号个人笔记(二)响应事件
  • 原文地址:https://www.cnblogs.com/liuwenhua/p/11605487.html
Copyright © 2011-2022 走看看