zoukankan      html  css  js  c++  java
  • 第二节,mnist手写字体识别

    1、获取mnist数据集,得到正确的数据格式

    mnist = input_data.read_data_sets('MNIST_data',one_hot=True)

    2、定义网络大小:图片的大小是28*28,784个像素点,输入神经元为784个,输出0~9个数,输出神经元为10个

    n_input =784
    n_layer1 = 10
    examples_to_show = 10 #显示的测试图像个数
    x_data = tf.placeholder(tf.float32,[None,n_input])
    y_data = tf.placeholder(tf.float32,[None,n_layer1])
     
    3、添加层函数
    inputsize:输入神经元的个数weights:权重biases:偏置值activation_function:激活函数
    输出:该层的进行激活后的神经元,下一个层的输入
    def addlayer(inputsize,weights,biases,activation_function=None):
        output = tf.add(tf.matmul(x_data,weights),biases)
        if activation_function == None:
            return tf.nn.sigmoid(output)
        else:
            return tf.nn.softmax(output)
     
    4、构建网络
    #预测输出
    添加隐藏层
    y_pre=addlayer(x_data,layer1_weights,layer1_biases,activation_function=tf.nn.softmax)
    y_true=y_data
    #反向
    cross_entropy=-tf.reduce_sum(y_true * tf.log(y_pre))
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    for i in range(1000):
         batch_xs,batch_ys =mnist.train.next_batch(100)
         sess.run(train_step,feed_dict={x_data:batch_xs , y_data:batch_ys})
         if (i%50==0):
            print ("loss : ",sess.run(cross_entropy,feed_dict={x_data:batch_xs , y_data:batch_ys}))
            # 这个越大越好
            print ("prediction acc : ",compute_acc(mnist.test.images[:100], mnist.test.labels[:100]))
    5、测试集计算模型预测准确率
    def compute_acc(x_input_test,y_true_test):
        y_pre_test = sess.run(y_pre,feed_dict={x_data:x_input_test, y_data:y_true_test})
        correct_prediction=tf.equal(tf.argmax(y_pre_test,1),tf.argmax(y_true_test,1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        result = sess.run(accuracy, feed_dict={x_data:x_input_test , y_data:y_true_test})
        return result
    6、结果
  • 相关阅读:
    Anaconda+Tensorflow环境安装与配置
    计算机视觉(视频追踪检测分类、监控追踪)常用测试数据集
    迁移学习( Transfer Learning )
    matlab函数_连通区域
    GMM的EM算法实现
    对​O​p​e​n​C​V​直​方​图​的​数​据​结​构​C​v​H​i​s​t​o​g​r​a​m​的​理​解
    opencv基于混合高斯模型的图像分割
    LNK1123: 转换到 COFF 期间失败: 文件无效或损坏
    视频测试序列(转)
    高职扩招,拿大专学历
  • 原文地址:https://www.cnblogs.com/wyx501/p/10440299.html
Copyright © 2011-2022 走看看