zoukankan      html  css  js  c++  java
  • softmax_regression完成mnist手写体数据集的识别

    ---恢复内容开始---

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    
    x = tf.placeholder(tf.float32, [None, 784])
    y_ = tf.placeholder(tf.float32, [None, 10])
    
    
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.nn.softmax(tf.matmul(x, W) + b)
    
    # 根据y, y_构造交叉熵损失
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y)))
    
    # 有了损失,我们就可以用随机梯度下降针对模型的参数(W和b)进行优化
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    sess = tf.InteractiveSession() 
    
    tf.global_variables_initializer().run() 
    
    print('start training...') 
    
    # 进行1000步梯度下降 
    
    for i in range(5000):
      batch = mnist.train.next_batch(50)# 默认 shuffle=True
      if i%100==0:
        train_accuracy = accuracy.eval({x:batch[0], y_:batch[1]})
        print("step:%d accuracy is %g" %(i , train_accuracy) )
    
    sess.run(train_step, feed_dict={x: batch[0], y_: batch[1]})
    
    #print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})) 
    
    print(accuracy.eval(feed_dict={x:mnist.test.images,y_:mnist.test.labels}))

    ---恢复内容结束---

  • 相关阅读:
    [BZOJ1015] [JSOI2008]星球大战starwar
    [BZOJ2321,LuoguP1861]星(之)器
    Google Search Operators
    Python blockchain
    CCAE词频表(转)
    python小技巧(转)
    Python著名的lib和开发框架(均为转载)
    Yarn取代job/task tracker
    hadoop 2.73‘s four xml
    HDFS NN,SNN,BN和HA
  • 原文地址:https://www.cnblogs.com/the-wolf-sky/p/10395510.html
Copyright © 2011-2022 走看看