zoukankan      html  css  js  c++  java
  • 7.交叉熵

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    # 载入数据集
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
    
    # 批次大小
    batch_size = 64
    # 计算一个周期一共有多少个批次
    n_batch = mnist.train.num_examples // batch_size
    
    # 定义两个placeholder
    x = tf.placeholder(tf.float32,[None,784])
    y = tf.placeholder(tf.float32,[None,10])
    
    # 创建一个简单的神经网络:784-10
    W = tf.Variable(tf.truncated_normal([784,10], stddev=0.1))
    b = tf.Variable(tf.zeros([10]) + 0.1)
    prediction = tf.nn.softmax(tf.matmul(x,W)+b)
    
    # 二次代价函数
    # loss = tf.losses.mean_squared_error(y, prediction)
    # 交叉熵
    loss = tf.losses.softmax_cross_entropy(y, prediction)
    
    # 使用梯度下降法
    train = tf.train.GradientDescentOptimizer(0.3).minimize(loss)
    
    # 结果存放在一个布尔型列表中
    correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
    # 求准确率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    
    with tf.Session() as sess:
        # 变量初始化
        sess.run(tf.global_variables_initializer())
        # 周期epoch:所有数据训练一次,就是一个周期
        for epoch in range(21):
            for batch in range(n_batch):
                # 获取一个批次的数据和标签
                batch_xs,batch_ys = mnist.train.next_batch(batch_size)
                sess.run(train,feed_dict={x:batch_xs,y:batch_ys})
            # 每训练一个周期做一次测试
            acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
            print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
    Extracting MNIST_data	rain-images-idx3-ubyte.gz
    Extracting MNIST_data	rain-labels-idx1-ubyte.gz
    Extracting MNIST_data	10k-images-idx3-ubyte.gz
    Extracting MNIST_data	10k-labels-idx1-ubyte.gz
    Iter 0,Testing Accuracy 0.7473
    Iter 1,Testing Accuracy 0.8413
    Iter 2,Testing Accuracy 0.9066
    Iter 3,Testing Accuracy 0.9113
    Iter 4,Testing Accuracy 0.9143
    Iter 5,Testing Accuracy 0.9168
    Iter 6,Testing Accuracy 0.9199
    Iter 7,Testing Accuracy 0.9201
    Iter 8,Testing Accuracy 0.9202
    Iter 9,Testing Accuracy 0.9213
    Iter 10,Testing Accuracy 0.921
    Iter 11,Testing Accuracy 0.9205
    Iter 12,Testing Accuracy 0.9214
    Iter 13,Testing Accuracy 0.923
    Iter 14,Testing Accuracy 0.9237
    Iter 15,Testing Accuracy 0.9238
    Iter 16,Testing Accuracy 0.924
    Iter 17,Testing Accuracy 0.9231
    Iter 18,Testing Accuracy 0.9246
    Iter 19,Testing Accuracy 0.925
    Iter 20,Testing Accuracy 0.9253
  • 相关阅读:
    The Brain vs Deep Learning Part I: Computational Complexity — Or Why the Singularity Is Nowhere Near
    unity3d NGUI多场景共用界面制作
    python第三方库系列之十九--python測试使用的mock库
    oracle之单行函数
    Andropid自己定义组件-坐标具体解释
    [WebGL入门]二,開始WebGL之前,先了解一下canvas
    【BZOJ2318】【spoj4060】game with probability Problem 概率DP
    苹果改版之后,关于隐私协议加入的问题解决方式
    Binary Tree Level Order Traversal II
    首届中国智慧城市协同创新峰会将于6月20日在大连隆重举行
  • 原文地址:https://www.cnblogs.com/liuwenhua/p/11605466.html
Copyright © 2011-2022 走看看