zoukankan      html  css  js  c++  java
  • 基于MNIST数据集使用TensorFlow训练一个包含一个隐含层的全连接神经网络

    包含一个隐含层的全连接神经网络结构如下:

                                                                                                   包含一个隐含层的神经网络结构图

    以MNIST数据集为例,以上结构的神经网络训练如下:

    #coding=utf-8
    from tensorflow.examples.tutorials.mnist import input_data
    import tensorflow as tf
    
    # 加载数据
    mnist = input_data.read_data_sets('/home/workspace/python/tf/data/mnist', one_hot=True)
    """
    # 创建模型
    x = tf.placeholder(tf.float32, [None, 784])
    W = tf.Variable(tf.zeros([784,10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.matmul(x, W) + b
    """
    x = tf.placeholder(tf.float32, [None, 784])
    W1 = tf.Variable(tf.truncated_normal([784, 500], stddev=0.1))
    b1 = tf.Variable(tf.zeros([500]))
    W2 = tf.Variable(tf.truncated_normal([500, 10], stddev=0.1))
    b2 = tf.Variable(tf.zeros([10]))
    layer1 = tf.nn.relu(tf.matmul(x, W1) + b1)
    y = tf.matmul(layer1, W2) + b2
    
    # 正确的样本标签
    y_ = tf.placeholder(tf.float32, [None, 10])
    
    # 损失函数选择softmax后的交叉熵,结果作为y的输出
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
    
    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()
    
    # 训练过程
    for _ in range(5000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
        if _%1000 == 0:
            # 使用测试集评估准确率
            correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
            print (sess.run(accuracy, feed_dict = {x: mnist.test.images,
                                                      y_: mnist.test.labels}))

    注意:权重向量初始化时使用tf.truncated_normal,而不要使用tf.zeros

    以上代码大概能得到97.98%的准确率。

    软件版本


    TensorFlow 1.0.1  +  Python 2.7.12

  • 相关阅读:
    年轻人的第一个 Spring Boot 应用,太爽了!
    面试问我 Java 逃逸分析,瞬间被秒杀了。。
    Spring Boot 配置文件 bootstrap vs application 到底有什么区别?
    坑爹的 Java 可变参数,把我整得够惨。。
    6月来了,Java还是第一!
    Eclipse 最常用的 10 组快捷键,个个牛逼!
    Spring Cloud Eureka 自我保护机制实战分析
    今天是 Java 诞生日,Java 24 岁了!
    厉害了,Dubbo 正式毕业!
    Spring Boot 2.1.5 正式发布,1.5.x 即将结束使命!
  • 原文地址:https://www.cnblogs.com/eczhou/p/7860527.html
Copyright © 2011-2022 走看看