  • [TensorFlow] Reading Notes: A Simple Convolutional Neural Network

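    If you visualize the layers of a CNN, you will find that the neuron activations in each layer correspond to a particular kind of information. The lower the layer, the closer it is to the texture of the image, like the material an object is made of. The higher the layer, the closer it is to the actual content (the information that lets you say what the thing is), like the category of an object.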

    Network structure

    Convolutional layer -> pooling layer -> convolutional layer -> pooling layer -> fully connected layer -> Softmax classifier

    The convolutional layers use ReLU as the activation function.

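    Because the convolutional layers use ReLU, their biases are initialized to a small positive value (0.1 in the code below) to help prevent dead ReLU nodes.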

    The fully connected layer also uses ReLU as the activation function.
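
    To make the bias initialization note above concrete, here is a minimal sketch in plain Python (no TensorFlow). The helper names and the input and weight values are hypothetical, chosen only for illustration. A ReLU unit whose pre-activation is negative outputs 0 and passes back a zero gradient; starting the bias at 0.1 rather than 0 nudges pre-activations toward the positive side.

```python
def relu(z):
    """ReLU activation: max(0, z)."""
    return z if z > 0.0 else 0.0


def relu_grad(z):
    """Derivative of ReLU with respect to its input (0 for z <= 0)."""
    return 1.0 if z > 0.0 else 0.0


def pre_activation(inputs, weights, bias):
    """Weighted sum of the inputs plus the bias."""
    return sum(x * w for x, w in zip(inputs, weights)) + bias


# Hypothetical values, for illustration only.
inputs = [0.5, 0.2]
weights = [0.1, -0.3]

z_zero_bias = pre_activation(inputs, weights, 0.0)   # slightly negative
z_small_bias = pre_activation(inputs, weights, 0.1)  # slightly positive

print(relu(z_zero_bias), relu_grad(z_zero_bias))    # unit is off, no gradient
print(relu(z_small_bias), relu_grad(z_small_bias))  # unit is on, gradient flows
```

    This only shows the mechanism for a single unit and a single input; whether a unit actually dies during training depends on the data and the learning rate.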

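    The pooling layers use SAME padding with a stride of 2. Like convolutional layers, pooling layers are usually set to SAME mode; in this mode a stride of 2 halves each spatial dimension (rounding up when the size is odd).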
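
    The size arithmetic can be checked with a short sketch in plain Python. It assumes TensorFlow's rule for SAME padding, where the output size is ceil(input size / stride), and it uses the feature-map counts (32 and 64) from the implementation in this post. The function names are illustrative.

```python
import math


def same_out_size(in_size, stride):
    """Spatial output size under SAME padding: ceil(in_size / stride)."""
    return math.ceil(in_size / stride)


def trace_shapes(height=28, width=28):
    """Follow an input through conv -> pool -> conv -> pool -> flatten."""
    shapes = [("input", (height, width, 1))]
    h, w = height, width
    # (layer name, stride, number of output channels)
    layers = [("conv1", 1, 32), ("pool1", 2, 32),
              ("conv2", 1, 64), ("pool2", 2, 64)]
    channels = 1
    for name, stride, channels in layers:
        h, w = same_out_size(h, stride), same_out_size(w, stride)
        shapes.append((name, (h, w, channels)))
    shapes.append(("flatten", (h * w * channels,)))
    return shapes


for name, shape in trace_shapes():
    print(name, shape)
```

    For a 28x28 input the trace ends at 7x7x64, which flattens to 3136 values. That matches the `7*7*64` input size of the fully connected layer in the code.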

    On the MNIST test set, this network reaches an accuracy of about 99.2%.

    Network implementation

    # Author : Hellcat
    # Time   : 2017/12/7
    
    # Written against the TensorFlow 1.x API (placeholders and sessions)
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    mnist = input_data.read_data_sets('../../../Mnist_data',one_hot=True)
    sess = tf.InteractiveSession()
    
    def weight_variable(shape):
        initial = tf.truncated_normal(shape,stddev=0.1)
        return tf.Variable(initial)
    
    def bias_variable(shape):
        # Initialize biases to a small positive constant to help prevent dead ReLU neurons
        initial = tf.constant(0.1, shape=shape)
        return tf.Variable(initial)
    
    def conv2d(x, W):
        return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')
    
    def max_pool_2x2(x):
        # 2x2 pooling with a stride of 2
        return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
    
    x = tf.placeholder(tf.float32, [None, 784])
    y_ = tf.placeholder(tf.float32, [None, 10])
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    
    # 5x5 filters, 1 input channel, 32 feature maps
    W_conv1 = weight_variable([5,5,1,32])
    b_conv1 = bias_variable([32])
    
    h_conv1 = tf.nn.relu((conv2d(x_image, W_conv1) + b_conv1))
    h_pool1 = max_pool_2x2(h_conv1)
    
    # 5x5 filters, 32 input channels, 64 feature maps
    W_conv2 = weight_variable([5,5,32,64])
    b_conv2 = bias_variable([64])
    
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)
    
    # 28x28 input; two stride-2 max-pooling steps (SAME) reduce it to 28/2/2, i.e. 7x7
    W_fc1 = weight_variable([7*7*64,1024])
    b_fc1 = bias_variable([1024])
    h_pool2_flat = tf.reshape(h_pool2, [-1,7*7*64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
    
    # Dropout layer; keep_prob is fed as 0.5 for training and 1.0 for evaluation
    keep_prob = tf.placeholder(tf.float32)
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
    
    W_fc2 = weight_variable([1024, 10])
    b_fc2 = bias_variable([10])
    y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
    
    # axis=1: sum along each row, i.e. over the 10 classes of each sample
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv),axis=1))
    train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
    
    correct_prediction = tf.equal(tf.argmax(y_conv,axis=1), tf.argmax(y_,axis=1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
    tf.global_variables_initializer().run()
    for i in range(20000):
        batch = mnist.train.next_batch(50)
        train_step.run(feed_dict={x:batch[0],y_:batch[1],keep_prob:0.5})
        if i % 100 == 0:
            train_accuracy = accuracy.eval(feed_dict={x:batch[0],y_:batch[1],keep_prob:1.0})
            print('step {0} traning accuracy {1:.3f}'.format(i,train_accuracy))
    
    print('test accuracy {}'.format(accuracy.eval(
        feed_dict={x:mnist.test.images,y_:mnist.test.labels,keep_prob:1.0})))
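
    One caveat about the loss above: it takes `tf.log` of the softmax output, so if a predicted probability underflows to exactly 0, the logarithm is no longer finite. The sketch below, in plain Python with illustrative function names, shows the underflow and the usual log-sum-exp remedy, which works from the pre-softmax logits. It is not the code the author ran. TensorFlow 1.x also provides `tf.nn.softmax_cross_entropy_with_logits`, which computes the loss directly from logits.

```python
import math


def naive_softmax(logits):
    """Softmax computed directly; small probabilities can underflow to 0.0."""
    exps = [math.exp(v) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(logits):
    """Numerically stable log(softmax(logits)) via the log-sum-exp trick."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]


def cross_entropy(one_hot, logits):
    """Cross-entropy of a one-hot label against logits, using log_softmax."""
    return -sum(t * lp for t, lp in zip(one_hot, log_softmax(logits)))


# Hypothetical logits with a very large gap between the two classes.
logits = [0.0, -800.0]
label = [0.0, 1.0]

probs = naive_softmax(logits)
print(probs)                         # the second probability underflows to 0.0
print(cross_entropy(label, logits))  # finite loss from the stable form
```

    With well-behaved activations the two formulations agree, which is consistent with the script above training without trouble.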
    

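    Convergence is reasonably good. The training accuracy is measured on the current mini-batch of 50 samples, so it moves in steps of 0.02 and fluctuates from one report to the next. The results for the first 1000 steps are: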

    step 0 traning accuracy 0.040
    step 100 traning accuracy 0.940
    step 200 traning accuracy 0.940
    step 300 traning accuracy 0.980
    step 400 traning accuracy 0.980
    step 500 traning accuracy 0.900
    step 600 traning accuracy 0.920
    step 700 traning accuracy 0.960
    step 800 traning accuracy 1.000
    step 900 traning accuracy 0.960
    step 1000 traning accuracy 1.000
    ……

    The results for the last few steps are:

    step 19000 traning accuracy 1.000
    step 19100 traning accuracy 1.000
    step 19200 traning accuracy 1.000
    step 19300 traning accuracy 1.000
    step 19400 traning accuracy 0.980
    step 19500 traning accuracy 1.000
    step 19600 traning accuracy 1.000
    step 19700 traning accuracy 1.000
    step 19800 traning accuracy 1.000
    step 19900 traning accuracy 1.000
    test accuracy 0.9927999973297119
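
    The long decimal in the last line is a float32 artifact. Assuming the standard MNIST test split of 10,000 images (the log itself does not state this), the figure corresponds to 9928 correct predictions, and the small plain-Python check below recovers that count. The helper name `to_float32` is illustrative.

```python
import struct


def to_float32(value):
    """Round a Python float (float64) to the nearest float32 and back."""
    return struct.unpack("f", struct.pack("f", value))[0]


test_set_size = 10000          # assumption: standard MNIST test split
reported = 0.9927999973297119  # value printed in the log above

correct = round(reported * test_set_size)
errors = test_set_size - correct

print(correct, errors)
print(to_float32(correct / test_set_size))
```

    The network's accuracy tensor is cast to `tf.float32`, so the exact ratio 0.9928 is stored as the nearest float32 value and printed with its full expansion.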

  • Original article: https://www.cnblogs.com/hellcat/p/7999376.html