zoukankan      html  css  js  c++  java
  • 12 使用卷积神经网络识别手写数字

        看代码:

     1 import tensorflow as tf
     2 from tensorflow.examples.tutorials.mnist import input_data
     3 
     4 # 下载训练和测试数据
     5 mnist = input_data.read_data_sets('MNIST_data/', one_hot = True)
     6 
     7 # 创建session
     8 sess = tf.Session()
     9 
    10 # 占位符
    11 x = tf.placeholder(tf.float32, shape=[None, 784]) # 每张图片28*28,共784个像素
    12 y_ = tf.placeholder(tf.float32, shape=[None, 10]) # 输出为0-9共10个数字,其实就是把图片分为10类
    13 
    14 # 权重初始化
    15 def weight_variable(shape):
    16     initial = tf.truncated_normal(shape, stddev=0.1) # 使用截尾正态分布的随机数初始化权重,标准偏差是0.1(噪音)
    17     return tf.Variable(initial)
    18 
    19 def bias_variable(shape):
    20     initial = tf.constant(0.1, shape = shape) # 使用一个小正数初始化偏置,避免出现偏置总为0的情况
    21     return tf.Variable(initial)
    22 
    23 # 卷积和集合
    24 def conv2d(x, W): # 计算2d卷积
    25     return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
    26 
    27 def max_pool_2x2(x): # 计算最大集合
    28     return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    29 
    30 # 第一层卷积
    31 W_conv1 = weight_variable([5, 5, 1, 32]) # 为每个5*5小块计算32个特征
    32 b_conv1 = bias_variable([32])
    33 
    34 x_image = tf.reshape(x, [-1, 28, 28, 1]) # 将图片像素转换为4维tensor,其中二三维是宽高,第四维是像素
    35 h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    36 h_pool1 = max_pool_2x2(h_conv1)
    37 
    38 # 第二层卷积
    39 W_conv2 = weight_variable([5, 5, 32, 64])
    40 b_conv2 = bias_variable([64])
    41 
    42 h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    43 h_pool2 = max_pool_2x2(h_conv2)
    44 
    45 # 密集层
    46 W_fc1 = weight_variable([7 * 7 * 64, 1024]) # 创建1024个神经元对整个图片进行处理
    47 b_fc1 = bias_variable([1024])
    48 
    49 h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
    50 h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
    51 
    52 # 退出(为了减少过度拟合,在读取层前面加退出层,仅训练时有效)
    53 keep_prob = tf.placeholder(tf.float32)
    54 h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
    55 
    56 # 读取层(最后我们加一个像softmax表达式那样的层)
    57 W_fc2 = weight_variable([1024, 10])
    58 b_fc2 = bias_variable([10])
    59 
    60 y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
    61 
    62 # 预测类和损失函数
    63 cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv)) # 计算偏差平均值
    64 train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) # 每一步训练
    65 
    66 # 评估
    67 correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
    68 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    69 sess.run(tf.global_variables_initializer())
    70 
    71 for i in range(1000):
    72     batch = mnist.train.next_batch(50)
    73     if i%10 == 0:
    74         train_accuracy = accuracy.eval(feed_dict={ x:batch[0], y_: batch[1], keep_prob: 1.0}, session = sess) # 每10次训练计算一次精度
    75         print("步数 %d, 精度 %g"%(i, train_accuracy))
    76     train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}, session = sess)
    77 
    78 # 关闭
    79 sess.close()

    执行上面的代码后输出:

    Extracting MNIST_data/train-images-idx3-ubyte.gz
    Extracting MNIST_data/train-labels-idx1-ubyte.gz
    Extracting MNIST_data/t10k-images-idx3-ubyte.gz
    Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
    步数 0, 精度 0.12
    步数 10, 精度 0.34
    步数 20, 精度 0.52
    步数 30, 精度 0.56
    步数 40, 精度 0.6
    步数 50, 精度 0.74
    步数 60, 精度 0.74
    步数 70, 精度 0.78
    步数 80, 精度 0.82

    ..........

    步数 900, 精度 0.96
    步数 910, 精度 0.98
    步数 920, 精度 0.96
    步数 930, 精度 0.98
    步数 940, 精度 0.98
    步数 950, 精度 0.9
    步数 960, 精度 0.98
    步数 970, 精度 0.9
    步数 980, 精度 1
    步数 990, 精度 0.9

        可以看到,使用卷积神经网络训练1000次可以让精度达到95%以上,据说训练20000次精度可以达到99.2%以上。由于CPU不行,太耗时间,就不训练那么多了。大家可以跟使用softmax训练识别手写数字进行对比。《07 训练Tensorflow识别手写数字

        参考资料

        1、Deep MNIST for Experts:https://www.tensorflow.org/get_started/mnist/pros

  • 相关阅读:
    Leetcode 811. Subdomain Visit Count
    Leetcode 70. Climbing Stairs
    Leetcode 509. Fibonacci Number
    Leetcode 771. Jewels and Stones
    Leetcode 217. Contains Duplicate
    MYSQL安装第三步报错
    .net 开发WEB程序
    JDK版本问题
    打开ECLIPSE 报failed to load the jni shared library
    ANSI_NULLS SQL语句
  • 原文地址:https://www.cnblogs.com/tengge/p/6920144.html
Copyright © 2011-2022 走看看