zoukankan      html  css  js  c++  java
  • CNN 手写数字识别

     

    1. 知识点准备

    在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念。

    a. 卷积

    关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性,以下面这个一维的卷积为例子:

    第一个特性是稀疏连接。可以看到, layer m 上的每一个节点都只与 layer m-1 对应区域的三个节点相连接。这个局部范围也叫感受野。第二个特性是相同颜色的线条代表了相同的权重,即权重共享。这样做有什么好处呢?一方面权重共享可以极大减小参数的数目,学习起来更加有效,另一方面,相同的权重可以让过滤器不受图像位置的影响来检测图像的特性,从而使 CNN 具有更强的泛化能力。

    b. 池化

    理论上,我们将图像利用不同的过滤器通过卷积之后得到了多个卷积之后的图像,然后直接利用这些图像进行分类,但是这样计算量太大了。利用池化操作可以将数据量减小,同时在一定程度上保留原有的图像特征。关于 pooling, 概念更加简单了,详情可以参考这里。池化又可以分为平均池化和最大池化,这里我们将采用最大池化。注意到,池化的区域是不重叠的,卷积的感受野是重叠的。

    2. 卷积神经网络的搭建

    下图是手写数字识别中采用的 lenet-5 简单的卷积神经网络模型:

    1. 原图是 28 × 28 的手写数字图片,通过第一次 20 个 5 × 5 的卷积核之后,得到 20 张卷积图片。卷积核的权重是取一定范围内的随机值,这样,一张 28 × 28 的图片就变为 20 张 (28-5+1)× (28-5+1)=24×24 的图片了。

    2. 将 24×24 的图片进行 2 × 2 的最大池化,得到 20 张 12 × 12 的图片。该图片的像素还需要进行 tanh 函数的变换才能作为下一个卷积层的输入。

    3. 将 tanh 变化之后的 12 × 12 大小的图片同样进行 20 × 50 个 5 × 5 的卷积操作之后得到 50 张 (12-5+1)× (12-5+1) = 8 × 8 的图片。

    4. 将 8×8 的图片进行 2×2 的最大池化,得到 50 张 4×4 的图片,再经过 tanh 函数进行归一化处理,就可以作为 MLP 的 800 个输入了。

    5. 余下来就是 MLP 的训练工作了。

    3. LR, MLP,CNN 识别代码

    已经训练好的模型系数的下载地址

    三种方法识别手写数字的代码:

      1 import cPickle
      2 
      3 import numpy
      4 
      5 import theano
      6 import theano.tensor as T
      7 from theano.tensor.signal import downsample
      8 from theano.tensor.nnet import conv
      9 
     10 ########################################
     11 # define the classifer constructs
     12 ########################################
     13 
     14 class LogisticRegression(object):
     15     def __init__(self, input, W=None, b=None):
     16 
     17         if W is None:
     18             fle = open("../model_param/lr_sgd_best.pkl")
     19             W, b = cPickle.load(fle)
     20             fle.close()
     21 
     22         self.W = W
     23         self.b = b
     24 
     25         self.outputs = T.nnet.softmax(T.dot(input, self.W) + b)
     26 
     27         self.pred = T.argmax(self.outputs, axis=1)
     28 
     29 class MLP(object):
     30     def __init__(self, input, params=None):
     31         if params is None:
     32             fle = open("../model_param/mlp_best.pkl")
     33             params = cPickle.load(fle)
     34             fle.close()
     35 
     36         self.hidden_W, self.hidden_b, self.lr_W, self.lr_b = params
     37 
     38         self.hiddenlayer = T.tanh(T.dot(input, self.hidden_W) + self.hidden_b)
     39 
     40         self.outputs = T.nnet.softmax(T.dot(self.hiddenlayer, self.lr_W) 
     41                     + self.lr_b)
     42 
     43         self.pred = T.argmax(self.outputs, axis=1)
     44 
     45 class CNN(object):
     46     def __init__(self, input, params=None):
     47         if params is None: 
     48             fle = open("../model_param/cnn_best.pkl")
     49             params = cPickle.load(fle)
     50             fle.close()
     51 
     52         ################
     53         self.layer3_W, self.layer3_b, self.layer2_W, self.layer2_b, 
     54             self.layer1_W, self.layer1_b, self.layer0_W, self.layer0_b = params
     55 
     56         # compute layer0 
     57         self.conv_out0 = conv.conv2d(input=input, filters=self.layer0_W)
     58 #                    filter_shape=(20, 1, 5, 5), image_shape=(1, 1, 
     59 #                        28, 28))
     60         self.pooled_out0 = downsample.max_pool_2d(input=self.conv_out0, 
     61                     ds=(2, 2), ignore_border=True)
     62         self.layer0_output = T.tanh(self.pooled_out0 + 
     63                     self.layer0_b.dimshuffle('x', 0, 'x', 'x'))
     64 
     65         # compute layer1 
     66         self.conv_out1 = conv.conv2d(input=self.layer0_output, filters=self.layer1_W)
     67 #                    filter_shape=(50, 20, 5, 5), image_shape=(1, 20, 
     68 #                        12, 12))
     69         self.pooled_out1 = downsample.max_pool_2d(input=self.conv_out1, 
     70                     ds=(2, 2), ignore_border=True)
     71         self.layer1_output = T.tanh(self.pooled_out1 + 
     72                     self.layer1_b.dimshuffle('x', 0, 'x', 'x'))
     73         
     74         # compute layer2
     75         self.layer2_input = self.layer1_output.flatten(2)
     76 
     77         self.layer2_output = T.tanh(T.dot(self.layer2_input, self.layer2_W) + 
     78                     self.layer2_b)
     79 
     80         # compute layer3
     81         self.outputs = T.nnet.softmax(T.dot(self.layer2_output, self.layer3_W)
     82                     + self.layer3_b)
     83 
     84         self.pred = T.argmax(self.outputs, axis=1)
     85 
     86 ########################################
     87 # build classifier
     88 ########################################
     89 
     90 def lr(input):
     91     input.shape = 1, -1
     92 
     93     x = T.fmatrix('x')
     94     classifer = LogisticRegression(input=x)
     95 
     96     get_p_y = theano.function(inputs=[x], outputs=classifer.outputs)
     97     pred_y = theano.function(inputs=[x], outputs=classifer.pred)
     98     return (get_p_y(input), pred_y(input))
     99 
    100 def mlp(input):
    101     input.shape = 1, -1
    102 
    103     x = T.fmatrix('x')
    104     classifer = MLP(input=x)
    105 
    106     get_p_y = theano.function(inputs=[x], outputs=classifer.outputs)
    107     pred_y = theano.function(inputs=[x], outputs=classifer.pred)
    108     return (get_p_y(input), pred_y(input))
    109 
    110 def cnn(input):
    111     input.shape = (1, 1, 28, 28)
    112     x = T.dtensor4('x')
    113     classifer = CNN(input=x)
    114     get_p_y = theano.function(inputs=[x], outputs=classifer.outputs)
    115     pred_y = theano.function(inputs=[x], outputs=classifer.pred)
    116     return (get_p_y(input), pred_y(input))
    View Code
  • 相关阅读:
    weblogic详解
    Java实现视频网站的视频上传、视频转码、及视频播放功能(ffmpeg)
    Java上传视频(mencoder)
    input标签type="file"上传文件的css样式
    jQuery系列:选择器
    jQuery系列:Ajax
    Sql Server系列:规范化及基本设计
    Sql Server系列:查询分页语句
    Sql Server系列:通用表表达式CTE
    Sql Server系列:子查询
  • 原文地址:https://www.cnblogs.com/daniel-D/p/3203459.html
Copyright © 2011-2022 走看看