zoukankan      html  css  js  c++  java
  • 12 使用卷积神经网络识别手写数字

        看代码:

     1 import tensorflow as tf
     2 from tensorflow.examples.tutorials.mnist import input_data
     3 
     4 # 下载训练和测试数据
     5 mnist = input_data.read_data_sets('MNIST_data/', one_hot = True)
     6 
     7 # 创建session
     8 sess = tf.Session()
     9 
    10 # 占位符
    11 x = tf.placeholder(tf.float32, shape=[None, 784]) # 每张图片28*28,共784个像素
    12 y_ = tf.placeholder(tf.float32, shape=[None, 10]) # 输出为0-9共10个数字,其实就是把图片分为10类
    13 
    14 # 权重初始化
    15 def weight_variable(shape):
    16     initial = tf.truncated_normal(shape, stddev=0.1) # 使用截尾正态分布的随机数初始化权重,标准偏差是0.1(噪音)
    17     return tf.Variable(initial)
    18 
    19 def bias_variable(shape):
    20     initial = tf.constant(0.1, shape = shape) # 使用一个小正数初始化偏置,避免出现偏置总为0的情况
    21     return tf.Variable(initial)
    22 
    23 # 卷积和集合
    24 def conv2d(x, W): # 计算2d卷积
    25     return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
    26 
    27 def max_pool_2x2(x): # 计算最大集合
    28     return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    29 
    30 # 第一层卷积
    31 W_conv1 = weight_variable([5, 5, 1, 32]) # 为每个5*5小块计算32个特征
    32 b_conv1 = bias_variable([32])
    33 
    34 x_image = tf.reshape(x, [-1, 28, 28, 1]) # 将图片像素转换为4维tensor,其中二三维是宽高,第四维是像素
    35 h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    36 h_pool1 = max_pool_2x2(h_conv1)
    37 
    38 # 第二层卷积
    39 W_conv2 = weight_variable([5, 5, 32, 64])
    40 b_conv2 = bias_variable([64])
    41 
    42 h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    43 h_pool2 = max_pool_2x2(h_conv2)
    44 
    45 # 密集层
    46 W_fc1 = weight_variable([7 * 7 * 64, 1024]) # 创建1024个神经元对整个图片进行处理
    47 b_fc1 = bias_variable([1024])
    48 
    49 h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
    50 h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
    51 
    52 # 退出(为了减少过度拟合,在读取层前面加退出层,仅训练时有效)
    53 keep_prob = tf.placeholder(tf.float32)
    54 h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
    55 
    56 # 读取层(最后我们加一个像softmax表达式那样的层)
    57 W_fc2 = weight_variable([1024, 10])
    58 b_fc2 = bias_variable([10])
    59 
    60 y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
    61 
    62 # 预测类和损失函数
    63 cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv)) # 计算偏差平均值
    64 train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) # 每一步训练
    65 
    66 # 评估
    67 correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
    68 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    69 sess.run(tf.global_variables_initializer())
    70 
    71 for i in range(1000):
    72     batch = mnist.train.next_batch(50)
    73     if i%10 == 0:
    74         train_accuracy = accuracy.eval(feed_dict={ x:batch[0], y_: batch[1], keep_prob: 1.0}, session = sess) # 每10次训练计算一次精度
    75         print("步数 %d, 精度 %g"%(i, train_accuracy))
    76     train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}, session = sess)
    77 
    78 # 关闭
    79 sess.close()

    执行上面的代码后输出:

    Extracting MNIST_data/train-images-idx3-ubyte.gz
    Extracting MNIST_data/train-labels-idx1-ubyte.gz
    Extracting MNIST_data/t10k-images-idx3-ubyte.gz
    Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
    步数 0, 精度 0.12
    步数 10, 精度 0.34
    步数 20, 精度 0.52
    步数 30, 精度 0.56
    步数 40, 精度 0.6
    步数 50, 精度 0.74
    步数 60, 精度 0.74
    步数 70, 精度 0.78
    步数 80, 精度 0.82

    ..........

    步数 900, 精度 0.96
    步数 910, 精度 0.98
    步数 920, 精度 0.96
    步数 930, 精度 0.98
    步数 940, 精度 0.98
    步数 950, 精度 0.9
    步数 960, 精度 0.98
    步数 970, 精度 0.9
    步数 980, 精度 1
    步数 990, 精度 0.9

        可以看到,使用卷积神经网络训练1000次可以让精度达到95%以上,据说训练20000次精度可以达到99.2%以上。由于CPU不行,太耗时间,就不训练那么多了。大家可以跟使用softmax训练识别手写数字进行对比。《07 训练Tensorflow识别手写数字

        参考资料

        1、Deep MNIST for Experts:https://www.tensorflow.org/get_started/mnist/pros

  • 相关阅读:
    form编码方式application/x-www-form-urlencoded和multipart/form-data的区别
    CentOS开启telnet服务
    借助英语搞清会计中“借”/“贷”的含义(转载)
    乘法器的Verilog HDL实现(转载)
    Meth | 关闭mac自带apache的启动
    Meth | Git冲突:commit your changes or stash them before you can merge. 解决办法
    Meth | Git 避免重复输入用户名和密码方法
    Meth | git Please move or remove them before you can merge
    Meth | git 常用命令
    Meth | 小团队git开发模式
  • 原文地址:https://www.cnblogs.com/tengge/p/6920144.html
Copyright © 2011-2022 走看看