zoukankan      html  css  js  c++  java
  • 第二节,mnist手写字体识别

    1、获取mnist数据集,得到正确的数据格式

    mnist = input_data.read_data_sets('MNIST_data',one_hot=True)

    2、定义网络大小:图片的大小是28*28,784个像素点,输入神经元为784个,输出0~9个数,输出神经元为10个

    n_input =784
    n_layer1 = 10
    examples_to_show = 10 #显示的测试图像个数
    x_data = tf.placeholder(tf.float32,[None,n_input])
    y_data = tf.placeholder(tf.float32,[None,n_layer1])
     
    3、添加层函数
    inputsize:输入神经元的个数weights:权重biases:偏置值activation_function:激活函数
    输出:该层的进行激活后的神经元,下一个层的输入
    def addlayer(inputsize,weights,biases,activation_function=None):
        output = tf.add(tf.matmul(x_data,weights),biases)
        if activation_function == None:
            return tf.nn.sigmoid(output)
        else:
            return tf.nn.softmax(output)
     
    4、构建网络
    #预测输出
    添加隐藏层
    y_pre=addlayer(x_data,layer1_weights,layer1_biases,activation_function=tf.nn.softmax)
    y_true=y_data
    #反向
    cross_entropy=-tf.reduce_sum(y_true * tf.log(y_pre))
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    for i in range(1000):
         batch_xs,batch_ys =mnist.train.next_batch(100)
         sess.run(train_step,feed_dict={x_data:batch_xs , y_data:batch_ys})
         if (i%50==0):
            print ("loss : ",sess.run(cross_entropy,feed_dict={x_data:batch_xs , y_data:batch_ys}))
            # 这个越大越好
            print ("prediction acc : ",compute_acc(mnist.test.images[:100], mnist.test.labels[:100]))
    5、测试集计算模型预测准确率
    def compute_acc(x_input_test,y_true_test):
        y_pre_test = sess.run(y_pre,feed_dict={x_data:x_input_test, y_data:y_true_test})
        correct_prediction=tf.equal(tf.argmax(y_pre_test,1),tf.argmax(y_true_test,1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        result = sess.run(accuracy, feed_dict={x_data:x_input_test , y_data:y_true_test})
        return result
    6、结果
  • 相关阅读:
    一个简单的CSS3+js 实现3D BOX
    jquery $.extend()扩展插件获取焦点或失去焦点事件
    菜单滑动
    全选反选
    纯css,编写菜单移入效果
    登录窗口抖动效果
    [WCF]WCF起航
    FastReport 数据过滤
    [Oracle]TRIGGER
    两种递归方法的比较
  • 原文地址:https://www.cnblogs.com/wyx501/p/10440299.html
Copyright © 2011-2022 走看看