zoukankan      html  css  js  c++  java
  • 『TensorFlow』读书笔记_简单卷积神经网络

    如果你可视化CNN的各层级结构,你会发现里面的每一层神经元的激活态都对应了一种特定的信息,越是底层的,就越接近画面的纹理信息,如同物品的材质。 越是上层的,就越接近实际内容(能说出来是个什么东西的那些信息),如同物品的种类。

    网络结构

    卷积层->池化层->卷积层->池化层->全连接层->Softmax分类器

    卷积层激活函数使用relu

    卷积层relu激活,偏置项使用极小值初始化,防止Relu出现死亡节点

    全连接层激活函数使用relu

    池化层模式使用SAME,所以stride取2,且池化层和卷积层一样,通常设置为SAME模式,本模式下stride=2正好实现1/2变换

    mnist测试集合上,结果可以达到99.2%左右的准确率。

    网络实现

    # Author : Hellcat
    # Time   : 2017/12/7
    
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    mnist = input_data.read_data_sets('../../../Mnist_data',one_hot=True)
    sess = tf.InteractiveSession()
    
    def weight_variable(shape):
        initial = tf.truncated_normal(shape,stddev=0.1)
        return tf.Variable(initial)
    
    def bias_variable(shape):
        # 偏置项使用极小值初始化,防止Relu出现死亡节点(dead neuron)
        initial = tf.constant(0.1, shape=shape)
        return tf.Variable(initial)
    
    def conv2d(x, W):
        return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')
    
    def max_pool_2x2(x):
        # 2x2池化,步长为2
        return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
    
    x = tf.placeholder(tf.float32, [None, 784])
    y_ = tf.placeholder(tf.float32, [None, 10])
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    
    # 5x5滤波器,1通道,32特征图
    W_conv1 = weight_variable([5,5,1,32])
    b_conv1 = bias_variable([32])
    
    h_conv1 = tf.nn.relu((conv2d(x_image, W_conv1) + b_conv1))
    h_pool1 = max_pool_2x2(h_conv1)
    
    # 5x5滤波器,32通道,64特征图
    W_conv2 = weight_variable([5,5,32,64])
    b_conv2 = bias_variable([64])
    
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)
    
    # 28x28,经过2次步长为2的最大池化(SAME),大小变为28/2/2,即7x7
    W_fc1 = weight_variable([7*7*64,1024])
    b_fc1 = bias_variable([1024])
    h_pool2_flat = tf.reshape(h_pool2, [-1,7*7*64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
    
    # dropout层
    keep_prob = tf.placeholder(tf.float32)
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
    
    W_fc2 = weight_variable([1024, 10])
    b_fc2 = bias_variable([10])
    y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
    
    # axis=1,按行来计算
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv),axis=1))
    train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
    
    correct_prediction = tf.equal(tf.argmax(y_conv,axis=1), tf.argmax(y_,axis=1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
    tf.global_variables_initializer().run()
    for i in range(20000):
        batch = mnist.train.next_batch(50)
        train_step.run(feed_dict={x:batch[0],y_:batch[1],keep_prob:0.5})
        if i % 100 == 0:
            train_accuracy = accuracy.eval(feed_dict={x:batch[0],y_:batch[1],keep_prob:1.0})
            print('step {0} traning accuracy {1:.3f}'.format(i,train_accuracy))
    
    print('test accuracy {}'.format(accuracy.eval(
        feed_dict={x:mnist.test.images,y_:mnist.test.labels,keep_prob:1.0})))
    

    收敛情况还不错,前1000轮结果如下,

    step 0 traning accuracy 0.040
    step 100 traning accuracy 0.940
    step 200 traning accuracy 0.940
    step 300 traning accuracy 0.980
    step 400 traning accuracy 0.980
    step 500 traning accuracy 0.900
    step 600 traning accuracy 0.920
    step 700 traning accuracy 0.960
    step 800 traning accuracy 1.000
    step 900 traning accuracy 0.960
    step 1000 traning accuracy 1.000
    ……

    最后几轮结果如下,

    step 19000 traning accuracy 1.000
    step 19100 traning accuracy 1.000
    step 19200 traning accuracy 1.000
    step 19300 traning accuracy 1.000
    step 19400 traning accuracy 0.980
    step 19500 traning accuracy 1.000
    step 19600 traning accuracy 1.000
    step 19700 traning accuracy 1.000
    step 19800 traning accuracy 1.000
    step 19900 traning accuracy 1.000
    test accuracy 0.9927999973297119

  • 相关阅读:
    iou与giou对比
    Linux学习第一天 vim
    奖励加分申请
    人月神话阅读笔记3
    5.27
    5.26
    5.25
    5.23
    5.22
    5.21
  • 原文地址:https://www.cnblogs.com/hellcat/p/7999376.html
Copyright © 2011-2022 走看看