  • Keras GRU Text Recognition

    GRU (Gated Recurrent Unit) is a variant of LSTM. Like LSTM, it mitigates the difficulty plain RNNs have with long-range dependencies.


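    GRU's structure is similar to LSTM's but simpler. It merges the forget and input gates into a single update gate and adds a reset gate. It also merges the cell state into the hidden state, so it has two gates instead of three and fewer parameters. It is one of the most commonly used LSTM variants.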



    The core LSTM module (figure not preserved):
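For reference, the standard LSTM cell is written below in common textbook notation. Here σ is the logistic sigmoid and ⊙ is element-wise multiplication. The symbols W, U and b are generic weight and bias names and may not match the notation of the original figure.

```latex
\begin{aligned}
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
\tilde{c}_t &= \tanh(W_c x_t + U_c h_{t-1} + b_c) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```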




    In GRU, this core module becomes (figure not preserved):
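The following is a minimal pure-Python sketch of one GRU time step. It follows the convention of Keras' GRU layer with `reset_after=False`. An update gate z and a reset gate r take the place of LSTM's three gates, and there is no separate cell state. The weight names (`Wz`, `Uz`, `bz` and so on) are hypothetical and chosen for illustration. Some references swap z and 1 − z in the final interpolation, which only relabels the gate.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def matvec(M, v):
    """Multiply matrix M (a list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def affine(W, x, U, h, b):
    """Element-wise W x + U h + b."""
    return [a + c + d for a, c, d in zip(matvec(W, x), matvec(U, h), b)]


def gru_step(x, h, p):
    """One GRU step; p maps the hypothetical names Wz/Uz/bz, Wr/Ur/br, Wh/Uh/bh to weights."""
    # Update gate: how much of the previous state is kept
    z = [sigmoid(v) for v in affine(p["Wz"], x, p["Uz"], h, p["bz"])]
    # Reset gate: how much of the previous state feeds the candidate
    r = [sigmoid(v) for v in affine(p["Wr"], x, p["Ur"], h, p["br"])]
    rh = [ri * hi for ri, hi in zip(r, h)]
    # Candidate state (there is no separate cell state, unlike LSTM)
    h_tilde = [math.tanh(v) for v in affine(p["Wh"], x, p["Uh"], rh, p["bh"])]
    # Interpolate between the previous state and the candidate
    return [zi * hi + (1.0 - zi) * ti for zi, hi, ti in zip(z, h, h_tilde)]


def zeros(rows, cols):
    return [[0.0] * cols for _ in range(rows)]


if __name__ == "__main__":
    n_in, n_h = 3, 2
    params = {
        "Wz": zeros(n_h, n_in), "Uz": zeros(n_h, n_h), "bz": [0.0] * n_h,
        "Wr": zeros(n_h, n_in), "Ur": zeros(n_h, n_h), "br": [0.0] * n_h,
        "Wh": zeros(n_h, n_in), "Uh": zeros(n_h, n_h), "bh": [0.0] * n_h,
    }
    # With all-zero weights z = 0.5 and the candidate is 0, so the state is halved
    print(gru_step([1.0, 2.0, 3.0], [1.0, -2.0], params))  # [0.5, -1.0]
```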


       


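    CTC network architecture definition. A convolutional feature extractor is followed by two bidirectional GRU layers, and the model is trained with CTC loss: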

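    # Keras 2.x-era API (standalone `keras` package)
    from keras import backend as K
    from keras.layers import (Input, Conv2D, MaxPooling2D, ZeroPadding2D, BatchNormalization,
                              Permute, TimeDistributed, Flatten, Bidirectional, GRU, Dense, Lambda)
    from keras.models import Model
    from keras.optimizers import SGD

    rnnunit = 256  # GRU units per direction; 256 matches the parameter counts in the summary below

    def ctc_lambda_func(args):
        # Not shown in the original post; a typical definition wrapping Keras' built-in CTC loss
        y_pred, labels, input_length, label_length = args
        return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

    def get_model(height, nclass):
        # Grayscale text-line image: fixed height, variable width
        inputs = Input(shape=(height, None, 1), name='the_input')

        # VGG-style convolutional feature extractor
        m = Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same', name='conv1')(inputs)
        m = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='pool1')(m)
        m = Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same', name='conv2')(m)
        m = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='pool2')(m)
        m = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same', name='conv3')(m)
        m = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same', name='conv4')(m)

        # Pad the width by 1 on each side, then pool with stride (2, 1):
        # the height is halved while the width (the future time axis) is kept
        m = ZeroPadding2D(padding=(0, 1))(m)
        m = MaxPooling2D(pool_size=(2, 2), strides=(2, 1), padding='valid', name='pool3')(m)

        # Note: axis=1 normalizes over the height axis (4 rows here, hence 16 params per
        # layer in the summary below); for channels-last data the channel axis would be -1
        m = Conv2D(512, kernel_size=(3, 3), activation='relu', padding='same', name='conv5')(m)
        m = BatchNormalization(axis=1)(m)
        m = Conv2D(512, kernel_size=(3, 3), activation='relu', padding='same', name='conv6')(m)
        m = BatchNormalization(axis=1)(m)
        m = ZeroPadding2D(padding=(0, 1))(m)
        m = MaxPooling2D(pool_size=(2, 2), strides=(2, 1), padding='valid', name='pool4')(m)
        # A 2x2 'valid' conv reduces the height to 1: (batch, 1, T, 512)
        m = Conv2D(512, kernel_size=(2, 2), activation='relu', padding='valid', name='conv7')(m)

        # (batch, 1, T, 512) -> (batch, T, 1, 512) -> (batch, T, 512): the width becomes time
        m = Permute((2, 1, 3), name='permute')(m)
        m = TimeDistributed(Flatten(), name='timedistrib')(m)

        # Two bidirectional recurrent layers (despite the 'blstm' names, these are GRUs)
        m = Bidirectional(GRU(rnnunit, return_sequences=True), name='blstm1')(m)
        m = Dense(rnnunit, name='blstm1_out', activation='linear')(m)
        m = Bidirectional(GRU(rnnunit, return_sequences=True), name='blstm2')(m)
        # Per-time-step softmax over nclass classes; Keras' CTC treats the last class as the blank
        y_pred = Dense(nclass, name='blstm2_out', activation='softmax')(m)

        # Inference model: image -> per-time-step class probabilities
        basemodel = Model(inputs=inputs, outputs=y_pred)

        # Training model: the CTC loss is computed inside the graph by a Lambda layer
        labels = Input(name='the_labels', shape=[None], dtype='float32')
        input_length = Input(name='input_length', shape=[1], dtype='int64')
        label_length = Input(name='label_length', shape=[1], dtype='int64')
        loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')(
            [y_pred, labels, input_length, label_length])
        model = Model(inputs=[inputs, labels, input_length, label_length], outputs=[loss_out])

        # Keras 2.x optimizer arguments (newer versions use learning_rate= and drop decay=)
        sgd = SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)
        # The 'ctc' output already is the loss, so the loss function just returns it.
        # Alternative: optimizer='adadelta'
        model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=sgd)
        model.summary()
        return model, basemodel

    Output of model.summary() for height=32 and nclass=5531: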

    ____________________________________________________________________________________________________
    Layer (type)                     Output Shape          Param #     Connected to                     
    ====================================================================================================
    the_input (InputLayer)           (None, 32, None, 1)   0                                            
    ____________________________________________________________________________________________________
    conv1 (Conv2D)                   (None, 32, None, 64)  640         the_input[0][0]                  
    ____________________________________________________________________________________________________
    pool1 (MaxPooling2D)             (None, 16, None, 64)  0           conv1[0][0]                      
    ____________________________________________________________________________________________________
    conv2 (Conv2D)                   (None, 16, None, 128) 73856       pool1[0][0]                      
    ____________________________________________________________________________________________________
    pool2 (MaxPooling2D)             (None, 8, None, 128)  0           conv2[0][0]                      
    ____________________________________________________________________________________________________
    conv3 (Conv2D)                   (None, 8, None, 256)  295168      pool2[0][0]                      
    ____________________________________________________________________________________________________
    conv4 (Conv2D)                   (None, 8, None, 256)  590080      conv3[0][0]                      
    ____________________________________________________________________________________________________
    zero_padding2d_1 (ZeroPadding2D) (None, 8, None, 256)  0           conv4[0][0]                      
    ____________________________________________________________________________________________________
    pool3 (MaxPooling2D)             (None, 4, None, 256)  0           zero_padding2d_1[0][0]           
    ____________________________________________________________________________________________________
    conv5 (Conv2D)                   (None, 4, None, 512)  1180160     pool3[0][0]                      
    ____________________________________________________________________________________________________
    batch_normalization_1 (BatchNorm (None, 4, None, 512)  16          conv5[0][0]                      
    ____________________________________________________________________________________________________
    conv6 (Conv2D)                   (None, 4, None, 512)  2359808     batch_normalization_1[0][0]      
    ____________________________________________________________________________________________________
    batch_normalization_2 (BatchNorm (None, 4, None, 512)  16          conv6[0][0]                      
    ____________________________________________________________________________________________________
    zero_padding2d_2 (ZeroPadding2D) (None, 4, None, 512)  0           batch_normalization_2[0][0]      
    ____________________________________________________________________________________________________
    pool4 (MaxPooling2D)             (None, 2, None, 512)  0           zero_padding2d_2[0][0]           
    ____________________________________________________________________________________________________
    conv7 (Conv2D)                   (None, 1, None, 512)  1049088     pool4[0][0]                      
    ____________________________________________________________________________________________________
    permute (Permute)                (None, None, 1, 512)  0           conv7[0][0]                      
    ____________________________________________________________________________________________________
    timedistrib (TimeDistributed)    (None, None, 512)     0           permute[0][0]                    
    ____________________________________________________________________________________________________
    blstm1 (Bidirectional)           (None, None, 512)     1181184     timedistrib[0][0]                
    ____________________________________________________________________________________________________
    blstm1_out (Dense)               (None, None, 256)     131328      blstm1[0][0]                     
    ____________________________________________________________________________________________________
    blstm2 (Bidirectional)           (None, None, 512)     787968      blstm1_out[0][0]                 
    ____________________________________________________________________________________________________
    blstm2_out (Dense)               (None, None, 5531)    2837403     blstm2[0][0]                     
    ____________________________________________________________________________________________________
    the_labels (InputLayer)          (None, None)          0                                            
    ____________________________________________________________________________________________________
    input_length (InputLayer)        (None, 1)             0                                            
    ____________________________________________________________________________________________________
    label_length (InputLayer)        (None, 1)             0                                            
    ____________________________________________________________________________________________________
    ctc (Lambda)                     (None, 1)             0           blstm2_out[0][0]                 
                                                                       the_labels[0][0]                 
                                                                       input_length[0][0]               
                                                                       label_length[0][0]               
    ====================================================================================================
    Total params: 10,486,715
    Trainable params: 10,486,699
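The parameter counts in the summary can be reproduced by hand, which also shows where GRU saves parameters compared with LSTM. The sketch below assumes the Keras 2.0-era GRU, which is equivalent to `reset_after=False`: three gate blocks, each with an input kernel, a recurrent kernel and one bias vector. Recent tf.keras defaults to `reset_after=True`, which adds a second bias per gate, so counts there would differ slightly. The gap of 16 between the total and trainable counts comes from the moving mean and variance of the two BatchNormalization layers.

```python
def conv2d_params(kh, kw, c_in, c_out):
    # kernel weights plus one bias per output channel
    return kh * kw * c_in * c_out + c_out


def dense_params(n_in, n_out):
    return n_in * n_out + n_out


def gru_params(n_in, units):
    # 3 blocks (update, reset, candidate): W (n_in x units), U (units x units), bias
    # Assumes reset_after=False, as in the Keras version that produced the summary
    return 3 * (n_in * units + units * units + units)


def lstm_params(n_in, units):
    # 4 blocks (input, forget, output, candidate), for comparison
    return 4 * (n_in * units + units * units + units)


def batchnorm_params(size):
    # gamma, beta (trainable) + moving mean, moving variance (non-trainable)
    return 4 * size


layers = {
    "conv1": conv2d_params(3, 3, 1, 64),
    "conv2": conv2d_params(3, 3, 64, 128),
    "conv3": conv2d_params(3, 3, 128, 256),
    "conv4": conv2d_params(3, 3, 256, 256),
    "conv5": conv2d_params(3, 3, 256, 512),
    "bn1": batchnorm_params(4),  # axis=1 -> normalizes over the height axis (size 4)
    "conv6": conv2d_params(3, 3, 512, 512),
    "bn2": batchnorm_params(4),
    "conv7": conv2d_params(2, 2, 512, 512),
    "blstm1": 2 * gru_params(512, 256),  # Bidirectional = one GRU per direction
    "blstm1_out": dense_params(512, 256),
    "blstm2": 2 * gru_params(256, 256),
    "blstm2_out": dense_params(512, 5531),
}

total = sum(layers.values())
non_trainable = 2 * (2 * 4)  # moving mean + variance (size 4 each) in both BN layers
print(total, total - non_trainable)  # 10486715 10486699
print(2 * lstm_params(512, 256))     # 1574912, vs 1181184 for the bidirectional GRU
```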


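    Model: the model covers about 5,500 characters. These include common Chinese characters, upper- and lowercase English letters, punctuation, and special symbols (@, ¥, &). Training can be continued from the existing model.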

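    Training: samples are stored in LMDB format in the data folder. train.py is the training script. It can save either the model weights alone or the model architecture plus weights, and it writes its results to the models folder.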
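When training through the CTC model, each batch supplies the image together with `the_labels`, `input_length` and `label_length`. The sketch below shows a minimal way to build the last three. It assumes a hypothetical character list `charset`; the project's real character file and LMDB loader are not shown in the post. It also relies on the architecture above, where the two 2x2 pools, the two padded stride-(2, 1) pools and the 2x2 'valid' conv turn an image of width W into W // 4 + 1 time steps.

```python
def output_time_steps(width):
    """Number of time steps get_model() produces for an image of the given width."""
    w = width // 2   # pool1: 2x2, stride 2 ('valid')
    w = w // 2       # pool2: 2x2, stride 2 ('valid')
    w = w + 2 - 1    # pad 1 px on each side, then pool3: width kernel 2, stride 1
    w = w + 2 - 1    # pad again, then pool4: width kernel 2, stride 1
    w = w - 1        # conv7: 2x2 kernel, 'valid'
    return w         # equals width // 4 + 1


def make_ctc_inputs(texts, widths, charset):
    """Build the_labels, input_length and label_length for one batch.

    charset is a hypothetical character list: character i maps to class i,
    leaving class len(charset) (the last one) for the CTC blank.
    """
    char_to_index = {ch: i for i, ch in enumerate(charset)}
    labels = [[char_to_index[ch] for ch in text] for text in texts]
    max_len = max(len(seq) for seq in labels)
    # Pad to a common length; entries beyond label_length are ignored by the CTC loss
    padded = [seq + [0] * (max_len - len(seq)) for seq in labels]
    input_length = [[output_time_steps(w)] for w in widths]
    label_length = [[len(seq)] for seq in labels]
    return padded, input_length, label_length


if __name__ == "__main__":
    charset = "济南华富锻造有限公司"  # hypothetical tiny charset
    labels, input_length, label_length = make_ctc_inputs(["华富", "有限公司"], [100, 100], charset)
    print(labels)        # [[2, 3, 0, 0], [6, 7, 8, 9]]
    print(input_length)  # [[26], [26]]
    print(label_length)  # [[2], [4]]
```

CTC also needs `input_length` to be large enough to align the label. It must be at least the label length plus the number of adjacent repeated characters, since each repeat needs a blank between its two copies.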

    Testing: test.py is the Chinese OCR test script.
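At inference time only `basemodel` is needed, since it maps an image to per-time-step softmax output. A common way to turn that output into text is greedy (best-path) CTC decoding: take the arg-max class at each step, collapse consecutive repeats, then drop blanks. The post does not show whether test.py decodes exactly this way. The sketch below assumes that the blank is the last class, which is Keras' `ctc_batch_cost` convention, and uses a hypothetical `charset`.

```python
def ctc_greedy_decode(probs, charset):
    """Best-path CTC decoding.

    probs: T rows of per-class probabilities (nclass = len(charset) + 1);
    the last class index is treated as the CTC blank.
    """
    blank = len(charset)
    # 1) most likely class at every time step
    best_path = [max(range(len(row)), key=row.__getitem__) for row in probs]
    chars = []
    prev = None
    for idx in best_path:
        # 2) collapse consecutive repeats, 3) drop blanks
        if idx != prev and idx != blank:
            chars.append(charset[idx])
        prev = idx
    return "".join(chars)


if __name__ == "__main__":
    charset = "abc"  # hypothetical 3-character set; class 3 is the blank

    def one_hot(i, n=4):
        return [1.0 if k == i else 0.0 for k in range(n)]

    path = [0, 0, 3, 0, 1, 1, 3, 3, 2]
    print(ctc_greedy_decode([one_hot(i) for i in path], charset))  # aabc
```

The blank between the two runs of "a" is what lets the decoder emit a doubled character. For batched tensors, Keras also provides `keras.backend.ctc_decode(y_pred, input_length, greedy=True)`.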


    Recognition results (the input images are not preserved; the recognized text is shown verbatim):


    济南华富锻造有限公司 (a company name, roughly "Jinan Huafu Forging Co., Ltd.")



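    夺得铜牌后,福民爱流下了激动的泪水。“石川 (roughly: "After winning the bronze medal, 福民爱 shed tears of excitement." The trailing “石川 begins another sentence.)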



    Itturnedoutthat328girswerenamedAbcdeintheUnitedstates




    Project (including the trained model): http://download.csdn.net/download/dcrmg/10248818


  • Original post: https://www.cnblogs.com/mtcnn/p/9411740.html