zoukankan      html  css  js  c++  java
  • keras可视化pydot graphviz问题

    1. 安装

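    pip install graphviz
    
    pip install pydot
    
    pip install pydot-ng  # 版本兼容需要
    
    # 注意: pip 安装的 graphviz 只是 Python 接口, pydot 绘图还需要系统安装 Graphviz 本体(dot 可执行文件),
    # 例如 Ubuntu: sudo apt-get install graphviz ; macOS: brew install graphviz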
    
    # 测试一下(这是 Keras 1.x 的导入路径; Keras 2.x 中改为 from keras.utils import plot_model)
    from keras.utils.visualize_util import plot

    2. 使用:

    #!/usr/bin/env python
    # coding=utf-8
    
    """
    利用keras cnn进行端到端的验证码识别, 简单直接暴力。
    迭代100次可以达到95%的准确率,但是很容易过拟合,泛化能力糟糕, 除了增加训练数据还没想到更好的方法.
    
    __author__: jkmiao
    __email__: miao1202@126.com
    __date__: 2017-02-08
    
    """
    from keras.models import Sequential
    from keras.layers import Dense, Dropout, Flatten, Activation, LSTM, Reshape
    from keras.layers import Convolution2D, MaxPooling2D
    from PIL import Image
    import os, random
    import numpy as np
    from keras.models import model_from_json
    from util import CharacterTable
    from keras.callbacks import ModelCheckpoint
    from sklearn.model_selection import train_test_split
    from keras.utils.visualize_util import plot
    
    
    def load_data(path='img/clearNoise/', threshold=180):
        """Load binarized captcha images and their text labels from ``path``.

        Every file in ``path`` whose name ends with ``jpg`` is read, converted
        to 8-bit grayscale ('L' mode) and binarized: pixels brighter than
        ``threshold`` become 1, all others 0. The label is the part of the
        file name before the first underscore, lower-cased
        (e.g. ``ab3d_001.jpg`` -> ``'ab3d'``).

        The file list is shuffled (with the global ``random`` state) before
        loading, so the order of the returned samples is random.

        Args:
            path: directory containing the captcha images.
            threshold: grayscale cut-off used for binarization; 180 keeps the
                value the script originally hard-coded.

        Returns:
            A ``(data, label)`` pair. ``data`` is a numpy array built from one
            (height, width, 1) 0/1 int array per image, i.e. shape
            (n_images, height, width, 1) assuming all images share one size.
            ``label`` is a list of the n_images lower-case label strings.
        """
        fnames = [os.path.join(path, fname) for fname in os.listdir(path)
                  if fname.endswith('jpg')]
        random.shuffle(fnames)
        data, label = [], []
        for fname in fnames:
            # basename() instead of split('/') so this also works with
            # platform-specific path separators.
            imgLabel = os.path.basename(fname).split('_')[0]
            # Open the file ourselves so the handle is closed deterministically;
            # convert() forces the pixel data to be read before the file closes.
            with open(fname, 'rb') as fh:
                gray = np.array(Image.open(fh).convert('L'))
            binary = (gray > threshold).astype(int)
            # Add a trailing channel axis: (h, w) -> (h, w, 1).
            data.append(binary.reshape((binary.shape[0], binary.shape[1], 1)))
            label.append(imgLabel.lower())
        return np.array(data), label
    
    # Load the binarized captcha images and one-hot encode their text labels.
    ctable = CharacterTable()
    data, label = load_data()
    # ctable.encode(lb) is expected to yield 216 values per label; 216 must
    # also match the output width of the network's final Dense layer.
    label_onehot = np.zeros((len(label), 216))
    for i, lb in enumerate(label):
        label_onehot[i, :] = ctable.encode(lb)
    # Single-argument print(...) prints the same on Python 2 and also runs on
    # Python 3 (the bare `print x` statement is a SyntaxError there).
    print(data.shape)
    print(label_onehot.shape)

    # Hold out 10% of the samples as a test set.
    x_train, x_test, y_train, y_test = train_test_split(data, label_onehot, test_size=0.1)
    
    # True: build a fresh CNN from scratch.
    # False: resume from the architecture and weights saved under model/.
    DEBUG = False

    # Build the model
    if DEBUG:
        model = Sequential()
        model.add(Convolution2D(32, 5, 5, border_mode='valid', input_shape=(60, 200, 1), name='conv1'))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Convolution2D(32, 3, 3, name='conv2'))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Flatten())
        # Experimental recurrent head, currently disabled:
        # model.add(Reshape((20, 60)))
        # model.add(LSTM(32))
        model.add(Dense(512))
        model.add(Activation('relu'))
        # 216 outputs must equal the width of the one-hot label vectors.
        # NOTE(review): one softmax over all 216 units makes the scores sum to
        # 1 across the whole vector; if each label encodes several characters,
        # a per-character softmax (or sigmoid) may fit better — confirm.
        model.add(Dense(216))
        model.add(Activation('softmax'))

    else:
        # `with` closes the JSON file handle once the architecture is read.
        with open('model/ba_cnn_model2.json') as fr:
            model = model_from_json(fr.read())
        model.load_weights('model/ba_cnn_model2.h5')

    # Compile. The deprecated Keras 0.x `class_mode` argument is not passed:
    # Keras 1.x only warns and drops it, and later versions reject it.
    model.compile(loss='mse', optimizer='adam', metrics=['accuracy'])
    model.summary()
    
    # Plot the network architecture to model.png (needs pydot + Graphviz).
    plot(model, to_file='model.png', show_shapes=True)

    # Train, checkpointing the weights with the lowest validation loss.
    check_pointer = ModelCheckpoint('./model/train_len_size1.h5', monitor='val_loss', verbose=1, save_best_only=True)
    model.fit(x_train, y_train, batch_size=32, nb_epoch=5, validation_split=0.1, callbacks=[check_pointer])

    # Save the final architecture and weights (the paths loaded when DEBUG is False).
    json_string = model.to_json()
    with open('./model/ba_cnn_model2.json', 'w') as fw:
        fw.write(json_string)
    model.save_weights('./model/ba_cnn_model2.h5')

    # Test: count held-out samples whose decoded prediction equals the decoded label.
    y_pred = model.predict(x_test, verbose=1)
    cnt = 0
    for i, (pred_vec, true_vec) in enumerate(zip(y_pred, y_test)):
        guess = ctable.decode(pred_vec)
        correct = ctable.decode(true_vec)
        if guess == correct:
            cnt += 1
        if i % 10 == 0:
            # Single-argument prints give the same text as the original Py2
            # `print a, b` statements and also run on Python 3.
            print('%s %s' % ('--' * 10, i))
            print('y_pred %s' % (guess,))
            print('y_test %s' % (correct,))
    # Fraction of test captchas recognised exactly.
    print(cnt / float(len(y_pred)))
    每天一小步,人生一大步!Good luck~
  • 相关阅读:
    vue父组件促发子组件中的方法
    油猴脚本:油猴脚本自动点击 | 自动检测元素并点击、休眠、顺序执行、单页面也适用
    油猴脚本:使用layer.js mobx lodash jquery
    vue项目统计src目录下代码行数
    常用mobx响应新值变化函数autorun和observe
    uni app使用mobx | uni app状态管理mobx
    File and Code Templates | webstorm代码文件模板 vue typescript
    javascript立即执行函数简单介绍
    VSCode 安装GitLens插件不生效问题
    常用的浅拷贝实现方法
  • 原文地址:https://www.cnblogs.com/jkmiao/p/6408610.html
Copyright © 2011-2022 走看看