zoukankan      html  css  js  c++  java
  • 3 TensorFlow入门之识别手写数字

    ————————————————————————————————————

    写在开头:此文参照莫烦python教程(墙裂推荐!!!)

    ————————————————————————————————————

    分类实验之识别手写数字

    • 这个实验的内容是:基于TensorFlow,实现手写数字的识别。
    • 这里用到的数据集是大家熟知的mnist数据集。
    • mnist有五万多张手写数字的图片,每个图片用28x28的像素矩阵表示。所以我们的输入层每个案列的特征个数就有28x28=784个;因为数字有0,1,2…9共十个,所以我们的输出层是个1x10的向量。输出层是十个小于1的非负数,表示该预测是0,1,2…9的概率,我们选取最大概率所对应的数字作为我们的最终预测。
    • 真实的数字表示为该数字所对应的位置为1,其余位置为0的1x10的向量。

    下面就开始实验啦!

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    #导入数据
    mnist = input_data.read_data_sets('MNIST_data',one_hot=True)#如果还没下载mnist就下载
    
    #定义添加层
    def add_layer(inputs,in_size,out_size,activation_function=None):
        #定义添加层内容,返回这层的outputs
        Weights = tf.Variable(tf.random_normal([in_size,out_size]))#Weigehts是一个in_size行、out_size列的矩阵,开始时用随机数填满
        biases = tf.Variable(tf.zeros([1,out_size])+0.1) #biases是一个1行out_size列的矩阵,用0.1填满
        Wx_plus_b = tf.matmul(inputs,Weights)+biases  #预测
        if activation_function is None:  #如果没有激励函数,那么outputs就是预测值
            outputs = Wx_plus_b
        else:  #如果有激励函数,那么outputs就是激励函数作用于预测值之后的值
            outputs = activation_function(Wx_plus_b)
        return outputs
    
    #定义计算正确率的函数
    def t_accuracy(t_xs,t_ys):
        global prediction
        y_pre = sess.run(prediction,feed_dict={xs:t_xs})
        correct_pre = tf.equal(tf.argmax(y_pre,1),tf.argmax(t_ys,1))
        accuracy = tf.reduce_mean(tf.cast(correct_pre,tf.float32))
        result = sess.run(accuracy,feed_dict={xs:t_xs,ys:t_ys})
        return result
    
    #定义神经网络的输入值和输出值
    xs = tf.placeholder(tf.float32,[None,784]) #None是不规定大小,这里指的是案例个数,而输入特征个数为28x28 = 784
    ys = tf.placeholder(tf.float32,[None,10]) #Nnoe也是案例个数,不做规定;10是因为有10个数字,所以输出是10
    
    #增加输出层
    prediction = add_layer(xs,784,10,activation_function=tf.nn.softmax)#这里的激励函数是softmax,此函数多用于多类分类
    
    #计算误差
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys*tf.log(prediction),reduction_indices=[1])) #此误差计算方式和softmax配套用,效果好
    
    #训练
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)#学习因子为0.5
    
    #开始训练
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    
    for i in range(1000):
        batch_xs,batch_ys = mnist.train.next_batch(100)   #提取数据集的100个数据,因为原来数据太大了
        sess.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys})
        if i%50 == 0:
            print (t_accuracy(mnist.test.images,mnist.test.labels))  #每隔50个,打印一下正确率。注意:这里是要用test的数据来测试
    
    Extracting MNIST_data	rain-images-idx3-ubyte.gz
    Extracting MNIST_data	rain-labels-idx1-ubyte.gz
    Extracting MNIST_data	10k-images-idx3-ubyte.gz
    Extracting MNIST_data	10k-labels-idx1-ubyte.gz
    0.1849
    0.6537
    0.7393
    0.7836
    0.8053
    0.8203
    0.8275
    0.837
    0.8465
    0.8504
    0.8567
    0.8571
    0.8643
    0.8637
    0.8664
    0.8687
    0.8719
    0.8742
    0.8763
    0.8773
    

    上面4行就是下载的mnist数据集的四个文件。然后看打印出来的正确率可知,这个网络的预测能力是越来越好的。
    下面试一下啊,抽取500个数据来训练,看看效果如何:

    for i in range(1000):
        batch_xs,batch_ys = mnist.train.next_batch(500)   #提取数据集的500个数据,因为原来数据太大了
        sess.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys})
        if i%50 == 0:
            print (t_accuracy(mnist.test.images,mnist.test.labels))  #每隔50个,打印一下正确率。注意:这里是要用test的数据来测试
    
    0.9001
    0.9022
    0.9023
    0.9026
    0.903
    0.903
    0.9037
    0.9036
    0.9034
    0.9027
    0.9041
    0.903
    0.9039
    0.9034
    0.9037
    0.9046
    0.9055
    0.9045
    0.9053
    0.905
    

    由上面打印出来的正确率可知,抽取500个数据来训练的话,正确率会达到90%


    *点击[这儿:TensorFlow]发现更多关于TensorFlow的文章*


  • 相关阅读:
    mac Navicat连接Oracle报错ORA-21561: OID generation failed
    svn: E230001: Server SSL certificate verification failed: certificate issued
    mac删除系统应用出现mac Read-Only filesystem
    spring boot项目03:阅读启动源码
    spring boot项目02:Web项目(基础)
    spring boot项目01:非Web项目(基础)
    idea 单独引入jar_Iidea 单独引入jar_Intellij IDEA 添加jar包的三种方式ntellij IDEA 添加jar包的三种方式
    java输出pdf的依赖包,非maven,包名:spire.pdf.jar 下载
    IDEA Error:java: 无效的源发行版: 11错误
    SpringBoot官网以下载模板方式创建
  • 原文地址:https://www.cnblogs.com/surecheun/p/9648968.html
Copyright © 2011-2022 走看看