zoukankan      html  css  js  c++  java
  • 2.tensorflow——Softmax回归

    import numpy as np
    import tensorflow as tf
    import matplotlib.pyplot as plt
    from tensorflow.examples.tutorials.mnist import input_data
    
    # Softmax regression on MNIST (TensorFlow 1.x graph API).
    # NOTE: tensorflow.examples.tutorials.mnist only exists in TensorFlow 1.x;
    # this script will not run unchanged on TensorFlow 2.

    # ---- Data ---------------------------------------------------------------
    # read_data_sets downloads MNIST into 'data/' if it is missing, then loads it.
    # one_hot=True makes every label a length-10 indicator vector.
    print("downloading...")
    mnist = input_data.read_data_sets('data/', one_hot=True)
    trainimg = mnist.train.images
    trainlabel = mnist.train.labels
    testimg = mnist.test.images
    testlabel = mnist.test.labels

    print("type:%s" % (type(mnist),))
    print("train data size:%d" % (mnist.train.num_examples,))
    print("test data size:%d" % (mnist.test.num_examples,))
    print("train label's shape: %s" % (trainlabel.shape,))

    # Set to True to display a few random training images with their labels.
    SHOW_EXAMPLES = False
    if SHOW_EXAMPLES:
        nsample = 5
        randidx = np.random.randint(trainimg.shape[0], size=nsample)
        for i in randidx:
            # Each image row is stored flattened; reshape it back to 28x28 for display.
            cur_img = np.reshape(trainimg[i, :], (28, 28))
            cur_label = np.argmax(trainlabel[i, :])
            plt.matshow(cur_img)
            print("" + str(i) + "th training data," + "which label is:" + str(cur_label))
            plt.show()

    # ---- 1. Hyperparameters -------------------------------------------------
    numClasses = 10
    inputSize = 784              # 28*28 pixels, flattened
    trainningIterations = 50000  # total number of mini-batch steps
    batch_size = 100             # examples per mini-batch (the single batch-size setting)
    # The loss below is averaged over the batch, so the step size does not grow
    # with batch_size. 0.5 is the value used by the TensorFlow beginner tutorial.
    learningRate = 0.5

    # ---- 2. Model: y = softmax(X @ W1 + B1) ---------------------------------
    X = tf.placeholder(tf.float32, shape=[None, inputSize])   # (batch, 784)
    y = tf.placeholder(tf.float32, shape=[None, numClasses])  # (batch, 10), one-hot

    # 2.1 Parameters, zero-initialized (fine for a convex linear model).
    W1 = tf.Variable(tf.zeros([inputSize, numClasses]))
    B1 = tf.Variable(tf.zeros([numClasses]))

    # 2.2 Graph: logits, probabilities, loss, optimizer, metrics.
    logits = tf.matmul(X, W1) + B1  # (batch, 10)
    y_pred = tf.nn.softmax(logits)  # (batch, 10) class probabilities
    # Computing the cross-entropy from the logits uses a log-softmax internally.
    # That avoids log(0) = -inf, which makes -sum(y * log(softmax)) produce NaN.
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))
    opt = tf.train.GradientDescentOptimizer(
        learning_rate=learningRate).minimize(cross_entropy)
    # A prediction is correct when the most probable class matches the one-hot label.
    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  # bool -> float

    # ---- 2.3 Train, then 2.4 evaluate ---------------------------------------
    # The context manager releases the session's resources when the script finishes.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(trainningIterations):
            batchInput, batchLabels = mnist.train.next_batch(batch_size)
            sess.run(opt, feed_dict={X: batchInput, y: batchLabels})
            if i % 1000 == 0:
                # Accuracy on the current mini-batch only: a noisy progress signal.
                train_accuracy = sess.run(
                    accuracy, feed_dict={X: batchInput, y: batchLabels})
                print("step %d, training accuracy %g" % (i, train_accuracy))

        # Evaluate on the whole test set rather than a single random batch.
        testAccuracy = sess.run(accuracy, feed_dict={X: testimg, y: testlabel})
        print("test accuracy %g" % (testAccuracy,))

     理论参考:

    http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html

  • 相关阅读:
    dpkg 删除 百度网盘 程序
    ubuntu 安装go
    解决 swap file “*.swp”already exists!问题
    ROS Topic 常用指令
    正交概念
    vim 永久显示行号 & 临时显示行号
    awk、grep、sed
    Keil中使用AStyle进行C语言的格式化
    红黑树学习
    802.11 对于multicast 和 broadcast的处理
  • 原文地址:https://www.cnblogs.com/yrm1160029237/p/11868697.html
Copyright © 2011-2022 走看看