zoukankan      html  css  js  c++  java
  • cnn进行端到端的验证码识别改进

    keras_cnn.py 训练及建模

    #!/usr/bin/env python
    # coding=utf-8
    
    """
    利用keras cnn进行端到端的验证码识别, 简单直接暴力。
    迭代100次可以达到95%的准确率,但是很容易过拟合,泛化能力糟糕, 除了增加训练数据还没想到更好的方法.
    
    __author__: jkmiao
    __email__: miao1202@126.com
    __date__: 2017-02-08
    
    """
    from keras.models import Model
    from keras.layers import Dense, Dropout, Flatten, Input, merge
    from keras.layers import Convolution2D, MaxPooling2D
    from keras.preprocessing.image import ImageDataGenerator
    from PIL import Image
    import os, random
    import numpy as np
    from keras.models import model_from_json
    from util import CharacterTable
    from keras.callbacks import EarlyStopping
    from sklearn.model_selection import train_test_split
    # from keras.utils.visualize_util import plot
    
    def load_data(path='img/clearNoise/'):
        """Load the labelled captcha images found in a directory.

        Every entry of ``path`` whose name ends with ``jpg`` is considered. The
        label is the part of the file name before the first underscore; files
        whose label is not exactly 6 characters long are reported on stdout and
        skipped. Each remaining image is converted to greyscale, binarised
        (pixels > 180 become 1, all others 0) and reshaped to ``(60, 200, 1)``,
        so the images must hold exactly 60 * 200 pixels.

        Args:
            path: directory containing the ``<label>_*.jpg`` image files.

        Returns:
            A ``(data, label)`` tuple. ``data`` is a numpy array stacking the
            binarised images in shuffled order; ``label`` is the list of
            lower-cased label strings in the same order.
        """
        fnames = [os.path.join(path, fname) for fname in os.listdir(path) if fname.endswith('jpg')]
        random.shuffle(fnames)
        data, label = [], []
        for fname in fnames:
            # basename() honours the platform path separator, unlike split('/').
            imgLabel = os.path.basename(fname).split('_')[0]
            if len(imgLabel) != 6:
                # The double space reproduces the former "print 'error: ', fname"
                # output while staying valid on both Python 2 and Python 3.
                print('error:  %s' % fname)
                continue
            imgM = np.array(Image.open(fname).convert('L'))
            imgM = 1 * (imgM > 180)
            data.append(imgM.reshape((60, 200, 1)))
            label.append(imgLabel.lower())
        return np.array(data), label
    
    ctable = CharacterTable()
    data, label = load_data()
    print data[0].max(), data[0].min()
    label_onehot = np.zeros((len(label), 216))
    for i, lb in enumerate(label):
        label_onehot[i,:] = ctable.encode(lb)
    print data.shape, data[-1].max(), data[-1].min()
    print label_onehot.shape
    
    
    # Mild random distortions applied to the training images; flipping is
    # disabled explicitly.
    datagen = ImageDataGenerator(
        rotation_range=5,
        width_shift_range=0.06,
        height_shift_range=0.06,
        shear_range=0.08,
        zoom_range=0.08,
        horizontal_flip=False,
    )

    datagen.fit(data)

    # Hold out 10% of the samples; the remaining 90% are used for training.
    x_train, x_test, y_train, y_test = train_test_split(data, label_onehot, test_size=0.1)
    
    DEBUG = False


    def _conv_branch(input_img, kernel_size):
        """Build one convolutional encoder branch and return its flattened output.

        The branch is: a ``kernel_size`` x ``kernel_size`` convolution with relu,
        2x2 max-pooling, two 3x3 convolutions, 2x2 max-pooling and a final 3x3
        convolution. All convolutions have 16 filters and 'same' border mode;
        only the first one is given an activation.

        Args:
            input_img: the Keras input tensor the branch is attached to.
            kernel_size: side length of the first convolution's square kernel.

        Returns:
            The flattened output tensor of the branch.
        """
        inner = Convolution2D(16, kernel_size, kernel_size, border_mode='same', activation='relu')(input_img)
        inner = MaxPooling2D(pool_size=(2, 2))(inner)
        inner = Convolution2D(16, 3, 3, border_mode='same')(inner)
        inner = Convolution2D(16, 3, 3, border_mode='same')(inner)
        inner = MaxPooling2D(pool_size=(2, 2))(inner)
        inner = Convolution2D(16, 3, 3, border_mode='same')(inner)
        return Flatten()(inner)


    # Model definition.
    # DEBUG = True builds a fresh network; DEBUG = False reloads the previously
    # saved 'ba_cnn_model3' architecture and weights.
    if DEBUG:
        input_img = Input(shape=(60, 200, 1))

        # Three parallel encoders over the same input, differing only in the size
        # of their first kernel (7, 5 and 3). They are created in this order, as
        # the original hand-written branches were.
        encoders = [_conv_branch(input_img, size) for size in (7, 5, 3)]

        features = merge(encoders, mode='concat', concat_axis=-1)
        features = Dropout(0.5)(features)
        features = Dense(216)(features)
        features = Dropout(0.5)(features)

        # Six independent 36-way softmax heads, concatenated into one
        # 6 * 36 = 216 wide output vector.
        heads = [Dense(36, activation='softmax')(features) for _ in range(6)]
        merged = merge(heads, mode='concat', concat_axis=-1)

        model = Model(input=input_img, output=merged)
    else:
        # Read the architecture through a context manager so the file handle is
        # closed deterministically.
        with open('model/ba_cnn_model3.json') as fr:
            model = model_from_json(fr.read())
        model.load_weights('model/ba_cnn_model3.h5')
    
    # Compile
    # Alternative optimizer kept for reference:
    # model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    model.summary()
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

    # plot(model, to_file='model3.png', show_shapes=True)

    # Train
    # Stop once the validation loss has not improved for 5 consecutive epochs.
    early_stopping = EarlyStopping(monitor='val_loss', patience=5)

    model.fit_generator(
        datagen.flow(x_train, y_train, batch_size=32),
        samples_per_epoch=len(x_train),
        nb_epoch=50,
        validation_data=(x_test, y_test),
        callbacks=[early_stopping],
    )

    # Persist the architecture (JSON) and the weights separately.
    json_string = model.to_json()
    with open('./model/ba_cnn_model4.json', 'w') as fw:
        fw.write(json_string)
    model.save_weights('./model/ba_cnn_model4.h5')

    # The files written above are 'ba_cnn_model4.*'; the message used to say cnn3.
    print('done saved model cnn4')
    
    # 测试
    y_pred = model.predict(x_test, verbose=1)
    cnt = 0
    for i in range(len(y_pred)):
        guess = ctable.decode(y_pred[i])
        correct = ctable.decode(y_test[i])
        if guess == correct:
            cnt += 1
        if i%10==0:
            print '--'*10, i
            print 'y_pred', guess
            print 'y_test', correct
    print cnt/float(len(y_pred))

    apicode.py  模型使用

    #!/usr/bin/env python
    # coding=utf-8
    
    from util import CharacterTable
    from keras.models import model_from_json
    from PIL import Image
    import matplotlib.pyplot as plt
    import os
    import numpy as np
    from prepare import clearNoise
    
    def img2vec(fname):
        data = []
        img = clearNoise(fname).convert('L')
        imgM = 1.0 * (np.array(img)>180)
        print imgM.max(), imgM.min()
        data.append(imgM.reshape((60, 200, 1)))
        return np.array(data), imgM
    
    # Shared converter between network output vectors and label strings.
    ctable = CharacterTable()

    # Restore the trained network: architecture from the JSON file, weights from
    # the .h5 file. The JSON file is read through a context manager so its handle
    # is closed instead of being leaked.
    with open('model/ba_cnn_model4.json') as fr:
        model = model_from_json(fr.read())
    model.load_weights('model/ba_cnn_model4.h5')
    
    def test(path):
        fnames = [ os.path.join(path, fname) for fname in os.listdir(path) ][:50]
        correct = 0
        for idx, fname in enumerate(fnames, 1):
            data, imgM = img2vec(fname)
            y_pred = model.predict(data)
            result = ctable.decode(y_pred[0])
            label = fname.split('/')[-1].split('_')[0]
            if result == label:
                correct += 1
                print 'correct', fname
            else:
                print result, label
            print 'accuracy: ',idx, float(correct)/idx
            print '=='*20
    #        plt.subplot(121)
    #        plt.imshow(Image.open(fname).convert('L'), plt.cm.gray)
    #        plt.title(fname)
    #
    #        plt.subplot(122)
    #        plt.imshow(imgM, plt.cm.gray)
    #        plt.title(result)
    #        plt.show()
    
    # Evaluate the images stored in the local 'test' directory.
    test('test')
    每天一小步,人生一大步!Good luck~
  • 相关阅读:
    为什么我们不要 .NET 程序员
    Jquery异步请求数据实例代码
    关系数据库中表的基本属性有哪些
    利用VC从DLL传递消息到EXE
    新实体与原实体之间为一对多关系
    本人C++ Builder开发的仿Windows桌面应用程序源码
    delphi窗体动态设计 在系统运行时动态更改控件属性
    DB.ASP 用Javascript写ASP很灵活很好用很easy
    CrazyScan Satellite scan software 卫星扫描
    delphi中窗体半透明效果如何实现
  • 原文地址:https://www.cnblogs.com/jkmiao/p/6531265.html
Copyright © 2011-2022 走看看