zoukankan      html  css  js  c++  java
  • 第十二节 简单神经网络实现手写数字识别

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data  # TensorFlow中提供的演示数据
    
    """
    神经网络的一些重要概念:
        全连接层:第N层和第N-1层之间的部分
        输出:有几个类别就有几个输出,每一个样本对每个类别都有一个概率值
        softmax:计算每个样本对所有类别的概率
        交叉熵:计算每个样本得出概率后,其损失值是多少
        反向传播:通过计算的损失值,不断的去调整权重,使得损失值达到最小
    """
    FLAGS = tf.app.flags.FLAGS
    tf.app.flags.DEFINE_integer("is_train", 1, "指定程序是预测还是训练")
    def full_connected():
    
        # 获取真实数据,会自动创建目录
        mnist = input_data.read_data_sets("./mnist/input_data/", one_hot=True)
    
        # 1.建立数据的占位符 特征值X [None, 784]  目标值y_true [None, 10]
        with tf.variable_scope("data"):
            # None代表不确定的样本数
            x = tf.placeholder(tf.float32, [None, 784])
            y_true = tf.placeholder(tf.float32, [None, 10])
    
        # 2.建立一个全连接层的神经网络 权重w [784, 10] 偏置b [10]
        with tf.variable_scope("fc_model"):
            # 随机初始化权重和偏置
            weight = tf.Variable(tf.random_normal([784, 10], mean=0.0, stddev=1.0, name="w"))
            bias = tf.Variable(tf.constant(0.0, shape=[10]))
    
            # 预测None个样本的输出结果,matmul矩阵相乘
            y_predict = tf.matmul(x, weight) + bias
    
        # 3.求所有样本的损失,然后求损失
        with tf.variable_scope("soft_cross"):
            # 求平均交叉熵损失
            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict))
    
        # 4.梯度下降求损失
        with tf.variable_scope("optimizer"):
            # 0.1是学习率,minimize表示求最小损失
            train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
    
        # 5.计算准确率,每一个样本对应每个特征值都有一个概率,tf.argmax(y_true, 1), tf.argmax(y_predict, 1)返回的是分别是真实值和预测值的正确的下标,equal判断两个下标是否一致,一样,则这个样本被标为1
        with tf.variable_scope("acc"):
            equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1))
    
            # equal_list None个样本 [1, 0, 1, 1, 0, 0.....]
            accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))
    
        # 收集变量
        tf.summary.scalar("losses", loss)
        tf.summary.scalar("acc", accuracy)
    
        # 收集高纬度变量
        tf.summary.histogram("weights", weight)
        tf.summary.histogram("biases", bias)
    
        # 定义初始化变量op
        init_op = tf.global_variables_initializer()
    
        # 定义一个合并变量op
        merged = tf.summary.merge_all()
    
        # 创建一个saver保存训练好的模型
        saver = tf.train.Saver()
    
        # 开启会话进行训练
        with tf.Session() as sess:
            # 初始化变量
            sess.run(init_op)
    
            #建立events文件,然后写入
            filewriter = tf.summary.FileWriter("./tmp/summary/test/", graph=sess.graph)
    
            if FLAGS.is_train ==1:
                # 迭代步数训练,更新参数
                for i in range(2000):
                    # 取出真实存在的特征值和目标值,50表示50个样本作为一个批次
                    mnist_x, mnist_y = mnist.train.next_batch(50)
    
                    # 运行训练op
                    sess.run(train_op, feed_dict={x:mnist_x, y_true:mnist_y})
    
                    # 写入每步训练的值
                    summary = sess.run(merged, feed_dict={x:mnist_x, y_true:mnist_y})
                    filewriter.add_summary(summary, i)
    
                    # feed_dict这里的参数是必须,但是没有实际意义
                    print("训练第{}步,准确率:{}".format(i, sess.run(accuracy, feed_dict={x:mnist_x, y_true:mnist_y})))
                # 保存模型
                saver.save(sess, "./tmp/ckpt/fc_model")
            else:
                # 加载模型
                saver.restore(sess, "tmp/ckpt/fc_model")
                # 进行预测
                for i in range(100):
                    x_test, y_test = mnist.test.next_batch(50)
                    print("第{}张图片,首先数字目标是:{},预测结果是:{}".format(i, tf.argmax(y_test, 1).eval(), tf.argmax(sess.run(y_predict, feed_dict={x:x_test, y_true:y_test}), 1).eval()))
    
        return None
    if __name__ == "__main__":
        full_connected()
  • 相关阅读:
    Quora 用了哪些技术(转)
    Instagram的技术探索2(转)
    Sharding & IDs at Instagram(转)
    2010“架构师接龙”问答--杨卫华VS赵劼(转)
    架构师接龙 汇总(转)
    如何成为一名软件架构师(转)
    网站架构资料集(转)
    技术好重要吗?
    洞洞那么大-悲伤那么小
    教你玩转XSS漏洞
  • 原文地址:https://www.cnblogs.com/kogmaw/p/12599467.html
Copyright © 2011-2022 走看看