zoukankan      html  css  js  c++  java
  • 第二节,mnist手写字体识别

    1、获取mnist数据集,得到正确的数据格式

    mnist = input_data.read_data_sets('MNIST_data',one_hot=True)

    2、定义网络大小:图片的大小是28*28,784个像素点,输入神经元为784个,输出0~9个数,输出神经元为10个

    n_input =784
    n_layer1 = 10
    examples_to_show = 10 #显示的测试图像个数
    x_data = tf.placeholder(tf.float32,[None,n_input])
    y_data = tf.placeholder(tf.float32,[None,n_layer1])
     
    3、添加层函数
    inputsize:输入神经元的个数weights:权重biases:偏置值activation_function:激活函数
    输出:该层的进行激活后的神经元,下一个层的输入
    def addlayer(inputsize,weights,biases,activation_function=None):
        output = tf.add(tf.matmul(x_data,weights),biases)
        if activation_function == None:
            return tf.nn.sigmoid(output)
        else:
            return tf.nn.softmax(output)
     
    4、构建网络
    #预测输出
    添加隐藏层
    y_pre=addlayer(x_data,layer1_weights,layer1_biases,activation_function=tf.nn.softmax)
    y_true=y_data
    #反向
    cross_entropy=-tf.reduce_sum(y_true * tf.log(y_pre))
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    for i in range(1000):
         batch_xs,batch_ys =mnist.train.next_batch(100)
         sess.run(train_step,feed_dict={x_data:batch_xs , y_data:batch_ys})
         if (i%50==0):
            print ("loss : ",sess.run(cross_entropy,feed_dict={x_data:batch_xs , y_data:batch_ys}))
            # 这个越大越好
            print ("prediction acc : ",compute_acc(mnist.test.images[:100], mnist.test.labels[:100]))
    5、测试集计算模型预测准确率
    def compute_acc(x_input_test,y_true_test):
        y_pre_test = sess.run(y_pre,feed_dict={x_data:x_input_test, y_data:y_true_test})
        correct_prediction=tf.equal(tf.argmax(y_pre_test,1),tf.argmax(y_true_test,1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        result = sess.run(accuracy, feed_dict={x_data:x_input_test , y_data:y_true_test})
        return result
    6、结果
  • 相关阅读:
    c++下使用邮槽实现进程间通信
    c++下基于windows socket的多线程服务器(基于TCP协议)
    C++实现线程同步的几种方式
    c++多线程编程:实现标准库accumulate函数的并行计算版本
    c++多线程在异常环境下的等待
    c++下基于windows socket的服务器客户端程序(基于UDP协议)
    c++下基于windows socket的单线程服务器客户端程序(基于TCP协议)
    C++解决error C4996报错
    Python读取UTF-8编码文件并使用命令行执行时输出结果的问题
    P4655 [CEOI2017]Building Bridges 题解
  • 原文地址:https://www.cnblogs.com/wyx501/p/10440299.html
Copyright © 2011-2022 走看看