zoukankan      html  css  js  c++  java
  • day13-TensorFlow简单神经网络实现手写数字识别

    
    # coding=utf-8
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    def numberRead():
        # 获取数据
        mnist = input_data.read_data_sets("../data/day06/", one_hot=True)
    
        # 1、准备数据集
        with tf.variable_scope("data"):
            # 准备占位符
            x = tf.placeholder(tf.float32,shape=[None,784])
            y_true = tf.placeholder(tf.int64,shape=[None,10])
    
            # 构建一个全连接层的网络,即权重和偏置
            weight = tf.Variable(tf.random_normal([784,10],mean=0.0,stddev=1.0))
            bias = tf.Variable(tf.random_normal([10],mean=0.0,stddev=1.0))
    
        # 2、构建模型
        with tf.variable_scope("model"):
            # None*784 乘 784*10 得到的结果为 None*10 即对应十个目标值
            y_predict = tf.matmul(x,weight) + bias
    
        # 3、模型参数计算
        with tf.variable_scope("model_soft_corss"):
            # 计算交叉熵损失
            softmax = tf.nn.softmax_cross_entropy_with_logits(labels=y_true,logits=y_predict)
            # 计算损失平均值
            loss = tf.reduce_mean(softmax)
    
        # 4、梯度下降(反向传播算法)优化模型
        with tf.variable_scope("model_better"):
            tarin_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    
        # 5、计算准确率
        with tf.variable_scope("model_acc"):
            # 计算出每个样本是否预测成功,结果为:[1,0,1,0,0,0,....,1]
            equal_list = tf.equal(tf.argmax(y_true,1),tf.argmax(y_predict,1))
    
            # 计算出准确率,先将预测是否成功换为float可以得到详细的准确率
            acc = tf.reduce_mean(tf.cast(equal_list,tf.float32))
    
    
        # 6、准备工作
        # 定义变量初始化op
        init_op = tf.global_variables_initializer()
        # 定义哪些变量记录
        tf.summary.scalar("losses",loss)
        tf.summary.scalar("acces",acc)
        tf.summary.histogram("weightes",weight)
        tf.summary.histogram("biases",bias)
        merge = tf.summary.merge_all()
    
        # 开启会话运行
        with tf.Session() as sess:
            # 变量初始化
            sess.run(init_op)
    
            # 开启记录
            filewriter = tf.summary.FileWriter("../summary/day06/",graph=sess.graph)
    
            for i in range(2500):
                # 准备数据
                mnist_x, mnist_y = mnist.train.next_batch(50)
    
                # 开始训练
                sess.run([tarin_op],feed_dict={x:mnist_x,y_true:mnist_y})
    
                # 得出训练的准确率,注意还需要将数据填入
                print("第%d次训练,准确率为:%f" % ((i+1),sess.run(acc, feed_dict={x: mnist_x, y_true: mnist_y})))
    
                # 写入每步训练的值
                summary = sess.run(merge,feed_dict={x:mnist_x,y_true:mnist_y})
                filewriter.add_summary(summary,i)
    
    
    
    
        return None
    
    
    if __name__ == '__main__':
        numberRead()
    
    
    
    
    
    

    mnist数据集获取地址:http://yann.lecun.com/exdb/mnist/

    训练效果:
    {{uploading-image-661248.png(uploading...)}}

  • 相关阅读:
    经典代码JSKeyword查看(M。。。$)的哦!
    django处理websocket
    产品所有者也应该是Scrum教练吗?
    google的javascript编码规范
    python 处理websocket
    [转] 虚拟座谈会:TDD有多美?
    python 数字相关
    google的python编码规范
    python 函数相关
    python推荐的模块结构
  • 原文地址:https://www.cnblogs.com/wuren-best/p/14311027.html