zoukankan      html  css  js  c++  java
  • 记录MNIST采用卷积方式实现与理解

    从时间上来说,这篇文章写的完了,因为这个实验早就做完了;但从能力上来说,这篇文章出现的早了,因为很多地方我都还没有理解。如果不现在写,不知道什么时候会有时间是其一,另外一个原因是怕自己过段时间忘记。

     1 #!/usr/bin/env python
     2 # -*- coding: utf-8 -*-
     3 
     4 # @Author  : mario
     5 # @File    : mnist_faltung.py
     6 # @Project : base
     7 # @Time    : 2018-12-19 14:11:38
     8 # @Desc    : File is ...
     9 
    10 import tensorflow as tf
    11 from tensorflow.examples.tutorials.mnist import input_data
    12 
    13 mnist = input_data.read_data_sets("data/", one_hot=True)
    14 
    15 
    16 def init_weight_variable(shape):
    17     initial = tf.truncated_normal(shape, stddev=0.1)
    18     return tf.Variable(initial)
    19 
    20 
    21 def init_bias_variable(shape):
    22     initial = tf.constant(0.1, shape=shape)
    23     return tf.Variable(initial)
    24 
    25 
    26 def conv2d(x, W):
    27     return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding="SAME")
    28 
    29 
    30 def max_pool_2x2(x):
    31     return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
    32 
    33 
    34 x = tf.placeholder(tf.float32, [None, 784])
    35 y_ = tf.placeholder(tf.float32, [None, 10])
    36 
    37 W_conv_1 = init_weight_variable([5, 5, 1, 32])
    38 b_conv_1 = init_bias_variable([32])
    39 
    40 x_image = tf.reshape(x, [-1, 28, 28, 1])
    41 
    42 h_conv_1 = tf.nn.relu(conv2d(x_image, W_conv_1) + b_conv_1)
    43 h_pool_1 = max_pool_2x2(h_conv_1)
    44 
    45 W_conv_2 = init_weight_variable([5, 5, 32, 64])
    46 b_conv_2 = init_bias_variable([64])
    47 
    48 h_conv_2 = tf.nn.relu(conv2d(h_pool_1, W_conv_2) + b_conv_2)
    49 h_pool_2 = max_pool_2x2(h_conv_2)
    50 
    51 W_fc_1 = init_weight_variable([7 * 7 * 64, 1024])
    52 b_fc_1 = init_bias_variable([1024])
    53 
    54 h_pool_flat = tf.reshape(h_pool_2, [-1, 7 * 7 * 64])
    55 h_fc_1 = tf.nn.relu(tf.matmul(h_pool_flat, W_fc_1) + b_fc_1)
    56 
    57 keep_prob = tf.placeholder(tf.float32)
    58 h_fc1_drop = tf.nn.dropout(h_fc_1, keep_prob)
    59 
    60 W_fc_2 = init_weight_variable([1024, 10])
    61 b_fc_2 = init_bias_variable([10])
    62 
    63 y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc_2) + b_fc_2)
    64 
    65 cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))
    66 train_step = tf.train.AdagradOptimizer(1e-4).minimize(cross_entropy)
    67 
    68 sess = tf.InteractiveSession()
    69 init = tf.global_variables_initializer()
    70 sess.run(init)
    71 
    72 for _ in range(20000):
    73     batch = mnist.train.next_batch(50)
    74     train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
    75 
    76 correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
    77 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    78 
    79 print(accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

    python版本:Python 3.6.5 (v3.6.5:f59c0932b4, Mar 28 2018, 05:52:31) ;tensor flow版本:1.12.0

    10、11行:导入必要模块

    13行:加载本地的数据

    16~18行:定义初始化权重变量函数

    21~23行:定义初始化偏置变量函数

    26~27行:定义一个步长为1,边距为0的2x2的卷积函数

    30~31行:定义一个2x2的池化函数

    34、35行:定义占位符x和y_,其中x是为了接收原始数据,y_是为了接受原始数据标签

    37、38行:初始化第一层卷积权重和偏置量

    40行:重塑原始数据结构,我们要把[n,784]这样的数据结构转换成卷积需要的[n,28,28,1],这里的n是指数据量,原始数据是将28x28像素的图片展开为784,卷积相当于我们先将数据还原为28x28,最后的1是指通道数

    42、43行:进行卷积操作,并将卷积结果池化

    45~49行:进行第二次卷积操作,同样也是将卷积结果池化

    51~61行:使用ReLU,其中57行和58行是为了防止过拟合

    63行:使用softmax算法确定其分类

    65行:计算交叉熵

    66行:利用交叉熵,调用AdagradOptimizer算法,训练模型

    68~70行:启用session,初始化变量

    72~74行:每次50个训练20000次

    76~79行:评估模型识别率

    也是遇到了很多的问题,但几乎都是因为不理解代码造成的,虽然现在代码是改对了,但是不理解的地方还是有很多,而且很多概念也是不理解,并且不知道实际上是做了什么操作,比如说卷积、池化等,倒是做了什么?感觉这个还是需要后续了解的。“路漫漫其修远兮,吾将上下而求索”用在这里再合适不过了。

  • 相关阅读:
    CentOS 7 镜像下载
    Ambari+HDP生产集群搭建(二)
    elasticsearch-head 关闭窗口服务停止解决方案
    git提交错误 error: failed to push some refs to
    git提交错误 git config --global user.email "you@example.com" git config --global user.name "Your Name
    Java SE入门(一)——变量与数据类型
    markdown基本语法
    numpy的基本API(四)——拼接、拆分、添加、删除
    数理统计(二)——Python中的概率分布API
    统计学习方法与Python实现(三)——朴素贝叶斯法
  • 原文地址:https://www.cnblogs.com/ben-mario/p/10180978.html
Copyright © 2011-2022 走看看