zoukankan      html  css  js  c++  java
  • tflearn 中文汉字识别,训练后模型存为pb给TensorFlow使用——模型层次太深,或者太复杂训练时候都不会收敛

    tflearn 中文汉字识别,训练后模型存为pb给TensorFlow使用。

    数据目录在data,data下放了汉字识别图片:

    data$ ls
    0  1  10  11  12  13  14  15  16  2  3  4  5  6  7  8  9
    data$ ls 0
    xxx.png yyy.png ....

    代码:

     如果将 get_model 里的模型层数加得非常深,训练时很可能不会收敛,精度一直停留在1%以内。

    # -*- coding: utf-8 -*-
    
    
    from __future__ import division, print_function, absolute_import
    
    import os
    import numpy as np
    import pickle
    import tflearn
    
    from PIL import Image
    from tflearn.layers.core import input_data, dropout, fully_connected
    from tflearn.layers.conv import conv_2d, max_pool_2d, avg_pool_2d
    from tflearn.layers.merge_ops import merge
    from tflearn.layers.estimator import regression
    from tflearn.data_utils import to_categorical, shuffle
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report, confusion_matrix
    from tflearn.layers.conv import highway_conv_2d, max_pool_2d
    from tflearn.layers.normalization import local_response_normalization, batch_normalization
    
    def resize_image(in_image, new_width, new_height, out_image=None,
                     resize_mode=None):
        """ Resize an image.

        Arguments:
            in_image: `PIL.Image`. The image to resize.
            new_width: `int`. The image new width.
            new_height: `int`. The image new height.
            out_image: `str`. If specified, save the image to the given path.
            resize_mode: PIL resampling filter handed to `Image.resize`.
                When omitted (None), `Image.LANCZOS` is used.

        Returns:
            `PIL.Image`. The resized image.
        """
        if resize_mode is None:
            # Resolved at call time rather than in the signature: the former
            # default, `Image.ANTIALIAS`, was an alias of LANCZOS that newer
            # Pillow releases removed, which made the `def` itself fail.
            resize_mode = Image.LANCZOS
        img = in_image.resize((new_width, new_height), resize_mode)
        if out_image:
            img.save(out_image)
        return img
    
    
    def convert_color(in_image, mode):
        """ Return the result of `in_image.convert(mode)`.

        Arguments:
            in_image: `PIL.Image`. The image to convert.
            mode: `str`. The target color mode, forwarded unchanged.
        """
        converted = in_image.convert(mode)
        return converted
    
    
    def pil_to_nparray(pil_image):
        """ Load a PIL.Image's pixel data and return it as a float32 numpy array. """
        # Make sure the pixel data has been read before converting.
        pil_image.load()
        pixels = np.asarray(pil_image, dtype=np.float32)
        return pixels
    
    
    def iterbrowse(path):
        """ Lazily yield the path of every file found below `path`.

        Each yielded value is the containing directory reported by `os.walk`
        joined with the file name. Directories themselves are not yielded.
        """
        for entry in os.walk(path):
            folder, names = entry[0], entry[2]
            for full_path in (os.path.join(folder, name) for name in names):
                yield full_path
    
    
    def directory_to_samples(directory, flags):
        """ Read a directory, and list all subdirectories files as class sample.

        Every immediate subdirectory of `directory` is one class. Classes are
        numbered from 0 in sorted (string) order of the subdirectory names.

        Arguments:
            directory: `str`. Root directory holding one subdirectory per class.
            flags: list of `str`, or None. A file is kept when its name
                contains at least one of these substrings (e.g. '.png').
                None or an empty list keeps every file.

        Returns:
            A `(samples, targets)` tuple of two parallel lists: file paths and
            their integer class labels.

        Raises:
            IOError: if `directory` cannot be walked (e.g. it does not exist).
        """
        samples = []
        targets = []
        # The builtin next() works on both Python 2 and 3, so no
        # try/except around `.next()` / `.__next__()` is needed.
        top = next(os.walk(directory), None)
        if top is None:
            # os.walk yields nothing for a missing/unreadable directory; fail
            # with a clear message instead of leaking a bare StopIteration.
            raise IOError("Cannot list directory: %r" % (directory,))
        classes = sorted(top[1])
        # label class is from 0 !!!
        for label, c in enumerate(classes):
            c_dir = os.path.join(directory, c)
            # A class directory that cannot be walked contributes no samples.
            filenames = next(os.walk(c_dir), (c_dir, [], []))[2]
            for sample in filenames:
                # `not flags` lets callers pass None to mean "no filtering".
                if not flags or any(flag in sample for flag in flags):
                    samples.append(os.path.join(c_dir, sample))
                    targets.append(label)
        return samples, targets
    
    
    # Get the pixel from the given image
    def get_pixel(image, i, j):
        """ Return the pixel at column `i`, row `j` of `image`.

        Arguments:
            image: object exposing `size` as (width, height) and
                `getpixel((x, y))`, such as a `PIL.Image`.
            i: `int`. Horizontal (x) coordinate.
            j: `int`. Vertical (y) coordinate.

        Returns:
            The value from `image.getpixel((i, j))`, or None when `i` or `j`
            lies at or beyond the image's width/height. Negative coordinates
            are not checked here and are handed to `getpixel` unchanged.
        """
        width, height = image.size
        # Valid coordinates run 0..width-1 and 0..height-1, so a coordinate
        # equal to the dimension is already outside the image.
        if i >= width or j >= height:
            return None
        return image.getpixel((i, j))
    
    
    # Create a Grayscale version of the image
    def convert_to_one_channel(image):
        """ Reduce an image array to a single channel, keeping the channel axis.

        The PNG files are assumed to be grayscale stored as RGB (R == G == B),
        so the first channel carries all the information. This assumption is
        not verified here.

        Arguments:
            image: `numpy.ndarray` shaped (rows, cols, channels), or
                (rows, cols) for an image that has no channel axis.

        Returns:
            `numpy.ndarray` shaped (rows, cols, 1).
        """
        if image.ndim == 2:
            # An image decoded without a channel axis would make the
            # three-index selection below raise IndexError; it is already
            # single-channel, so only the trailing axis needs to be added.
            return image[:, :, np.newaxis]
        # Indexing with the list [0] (not the scalar 0) keeps the last axis.
        return image[:, :, [0]]
    
    
    
    def image_dirs_to_samples(directory, resize=None, convert_gray=False,
                              filetypes=None):
        """ Load every image under `directory` into a one-channel array.

        Arguments:
            directory: `str`. Root directory with one subdirectory per class.
            resize: `(width, height)` or None. When given, each image is
                resized to this size before conversion.
            convert_gray: currently ignored (the grayscale conversion step is
                disabled).
            filetypes: list of `str`. Passed to `directory_to_samples` as its
                `flags` filter.

        Returns:
            A `(samples, targets)` tuple: the list of image arrays (pixel
            values divided by 255) and the list of integer class labels.
        """
        print("Starting to parse images...")
        samples, targets = directory_to_samples(directory, flags=filetypes)
        for index in range(len(samples)):
            path = samples[index]
            print("Process %d th file %s" % (index + 1, path))
            picture = Image.open(path)  # Returns a PIL.Image.
            if resize:
                # TODO: resizing strategy is still open.
                target_width, target_height = resize[0], resize[1]
                picture = resize_image(picture, target_width, target_height)
            # TODO: convert to more data; `convert_gray` is not applied yet.
            pixels = pil_to_nparray(picture)
            pixels /= 255.  # Normalize pixel values by dividing by 255.
            # Only one channel is wanted; the path entry is replaced in place.
            samples[index] = convert_to_one_channel(pixels)
        print("Parsing Done!")
        return samples, targets
    
    
    def load_data(dirname, resize_pics=(128, 128), shuffle_data=True):
        """ Load the image dataset, using `<dirname>/data.pkl` as a cache.

        If the cache file can be unpickled it is returned as-is; in that case
        `resize_pics` and `shuffle_data` have no effect, so delete the cache
        after changing them. Otherwise the images under `dirname` are parsed,
        the labels one-hot encoded, and the result written back to the cache.

        Arguments:
            dirname: `str`. Directory with one subdirectory per class.
            resize_pics: `(width, height)`. Size every image is resized to.
            shuffle_data: `bool`. Shuffle samples and labels together when
                the dataset is (re)built.

        Returns:
            `(X, Y, org_labels)`: image data, one-hot labels, and the integer
            labels, all in the same order.

        Raises:
            ValueError: if no '.jpg' / '.png' sample is found under `dirname`.
        """
        dataset_file = os.path.join(dirname, 'data.pkl')
        try:
            # NOTE: unpickling can execute arbitrary code; only load a cache
            # file that this script wrote itself.
            with open(dataset_file, 'rb') as cache:
                X, Y, org_labels = pickle.load(cache)
        except Exception:
            # Deliberately broad: a missing, truncated or incompatible cache
            # all mean "rebuild from the images".
            # TODO, memory is too small to load all data
            X, Y = image_dirs_to_samples(dirname, resize_pics, False,
                                         ['.jpg', '.png'])
            if not Y:
                # np.max() below would fail with an opaque "zero-size array"
                # error; report the real problem instead.
                raise ValueError("No .jpg/.png samples found under %r" % (dirname,))
            org_labels = Y
            # First class is '0', so the number of classes is max label + 1.
            Y = to_categorical(Y, np.max(Y) + 1)
            if shuffle_data:
                X, Y, org_labels = shuffle(X, Y, org_labels)
            with open(dataset_file, 'wb') as cache:
                pickle.dump((X, Y, org_labels), cache)
        return X, Y, org_labels
    
    
    class EarlyStoppingCallback(tflearn.callbacks.Callback):
        """ Stop training once both accuracies reach a threshold.

        Training is interrupted by raising `StopIteration` from
        `on_epoch_end`; the code calling `model.fit` is expected to catch it.
        """

        def __init__(self, val_acc_thresh):
            # Threshold compared against both the validation accuracy and the
            # training accuracy at the end of each epoch.
            self.val_acc_thresh = val_acc_thresh

        def on_epoch_end(self, training_state):
            """
            Raise `StopIteration` when `training_state.val_acc` and
            `training_state.acc_value` are both at least `val_acc_thresh`.
            """
            val_acc = training_state.val_acc
            train_acc = training_state.acc_value
            if val_acc is None or train_acc is None:
                # No accuracy available yet; comparing None with a float
                # raises TypeError on Python 3, so just keep training.
                return
            if val_acc >= self.val_acc_thresh and train_acc >= self.val_acc_thresh:
                # Announce termination only when training really stops; it
                # used to be printed after every epoch.
                print("Terminating training at the end of epoch", training_state.epoch)
                raise StopIteration

        def on_train_end(self, training_state):
            """
            Report the final training accuracy. Per the original author's
            note, tflearn calls this when training ends, including after an
            early stop.
            """
            print("Successfully left training! Final model accuracy:", training_state.acc_value)
    
    
    def get_model(width, height, classes=40):
        """ Build the CNN used for character classification.

        Layers: conv(32) -> max-pool -> conv(64) -> conv(64) -> max-pool ->
        fully-connected(512) -> dropout -> softmax over `classes`, trained
        with Adam on categorical cross-entropy.

        Arguments:
            width: `int`. Input image width in pixels.
            height: `int`. Input image height in pixels.
            classes: `int`. Number of output classes.

        Returns:
            `tflearn.DNN` wrapping the network.
        """
        # TODO, modify model
        # Real-time data preprocessing
        img_prep = tflearn.ImagePreprocessing()
        img_prep.add_featurewise_zero_center(per_channel=True)
        img_prep.add_featurewise_stdnorm()
        # Image arrays are laid out (rows, cols, channels), i.e.
        # (height, width, 1); the previous [width, height] order only worked
        # because the script uses square images.
        network = input_data(shape=[None, height, width, 1], data_preprocessing=img_prep)  # if RGB, 224,224,3
        network = conv_2d(network, 32, 3, activation='relu')
        network = max_pool_2d(network, 2)
        network = conv_2d(network, 64, 3, activation='relu')
        network = conv_2d(network, 64, 3, activation='relu')
        network = max_pool_2d(network, 2)
        network = fully_connected(network, 512, activation='relu')
        network = dropout(network, 0.5)
        network = fully_connected(network, classes, activation='softmax')
        network = regression(network, optimizer='adam',
                             loss='categorical_crossentropy',
                             learning_rate=0.001)
        model = tflearn.DNN(network, tensorboard_verbose=0)
        return model
    
    
    if __name__ == "__main__":
        # Train the CNN on the images under ./data, stop early once the
        # accuracy threshold is reached, save the model, then report metrics
        # over the whole dataset.
        width, height = 32, 32
        X, Y, org_labels = load_data(dirname="data", resize_pics=(width, height))
        trainX, testX, trainY, testY = train_test_split(X, Y, test_size=0.2, random_state=666)
        print("sample data:")
        print(trainX[0])
        print(trainY[0])
        print(testX[-1])
        print(testY[-1])

        # Size the softmax layer from the one-hot labels instead of a
        # hard-coded 100, so the model always matches the dataset on disk.
        num_classes = np.asarray(Y).shape[1]
        model = get_model(width, height, classes=num_classes)

        filename = 'cnn_handwrite-acc0.8.tflearn'
        # try to load model and resume training
        #try:
        #    model.load(filename)
        #    print("Model loaded OK. Resume training!")
        #except:
        #    pass

        # Initialize our callback with desired accuracy threshold.
        early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.9)
        try:
            model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
                      snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
                      show_metric=True, batch_size=32, callbacks=early_stopping_cb, run_id='cnn_handwrite')
        except StopIteration:
            # Raised by EarlyStoppingCallback: the normal way out of training.
            print("OK, stop iterate!Good!")

        model.save(filename)

        # predict all data and calculate confusion_matrix
        # (reloading first also checks that the saved file can be read back)
        model.load(filename)

        pro_arr = model.predict(X)
        predict_labels = np.argmax(pro_arr, axis=1)
        print(classification_report(org_labels, predict_labels))
        print(confusion_matrix(org_labels, predict_labels))
    
    
    

     运行效果:100个汉字2分钟就可以达到95%精度。

    ---------------------------------
    Run id: cnn_handwrite
    Log directory: /tmp/tflearn_logs/
    ---------------------------------
    Preprocessing... Calculating mean over all dataset (this may take long)...
    Mean: [ 0.89235026] (To avoid repetitive computation, add it to argument 'mean' of `add_featurewise_zero_center`)
    ---------------------------------
    Preprocessing... Calculating std over all dataset (this may take long)...
    STD: 0.192279 (To avoid repetitive computation, add it to argument 'std' of `add_featurewise_stdnorm`)
    ---------------------------------
    Training samples: 19094
    Validation samples: 4774
    --
    Training Step: 597  | total loss: 0.70288 | time: 40.959ss
    | Adam | epoch: 001 | loss: 0.70288 - acc: 0.7922 | val_loss: 0.54380 - val_acc: 0.8460 -- iter: 19094/19094
    --
    Terminating training at the end of epoch 1
     Training Step: 1194  | total loss: 0.48860 | time: 40.245s
    | Adam | epoch: 002 | loss: 0.48860 - acc: 0.8783 | val_loss: 0.37020 - val_acc: 0.8923 -- iter: 19094/19094
    --
    Terminating training at the end of epoch 2
    Training Step: 1791  | total loss: 0.35790 | time: 41.315ss
    | Adam | epoch: 003 | loss: 0.35790 - acc: 0.9090 | val_loss: 0.34719 - val_acc: 0.9049 -- iter: 19094/19094
    --
    Terminating training at the end of epoch 3
    Successfully left training! Final model accuracy: 0.908959209919
    OK, stop iterate!Good!
                 precision    recall  f1-score   support
    
              0       1.00      0.99      0.99       239
              1       0.95      0.96      0.96       237
              2       0.91      0.98      0.94       240
              3       0.90      0.98      0.94       239
              4       0.96      0.98      0.97       239
              5       0.94      0.97      0.96       239
              6       0.98      0.98      0.98       239
              7       0.84      0.99      0.91       240
              8       0.99      0.87      0.93       239
              9       0.95      0.98      0.96       239
             10       0.97      0.94      0.96       240
             11       0.95      0.98      0.97       240
             12       0.92      0.99      0.95       240
             13       0.95      0.96      0.96       239
             14       0.94      0.94      0.94       236
             15       0.94      0.97      0.96       240
             16       0.94      0.98      0.96       240
             17       0.97      0.99      0.98       240
             18       0.94      0.93      0.94       240
             19       1.00      0.95      0.98       239
             20       0.96      0.98      0.97       240
             21       0.98      0.91      0.95       239
             22       0.97      0.95      0.96       239
             23       1.00      0.97      0.98       239
             24       0.94      0.98      0.96       240
             25       0.98      0.98      0.98       237
             26       0.91      1.00      0.95       239
             27       0.91      0.96      0.93       239
             28       0.97      0.88      0.92       239
             29       1.00      0.98      0.99       240
             30       0.99      0.94      0.96       239
             31       0.97      0.97      0.97       237
             32       0.94      0.98      0.96       236
             33       0.94      0.96      0.95       239
             34       0.98      0.99      0.98       239
             35       0.99      0.98      0.99       240
             36       0.96      0.92      0.94       239
             37       1.00      0.93      0.96       240
             38       0.96      0.99      0.98       238
             39       0.98      0.97      0.97       238
             40       0.92      0.90      0.91       240
             41       0.96      0.97      0.96       237
             42       0.98      0.97      0.97       240
             43       0.95      0.96      0.95       239
             44       0.97      0.96      0.97       239
             45       0.95      0.94      0.95       239
             46       0.93      0.96      0.94       232
             47       0.98      0.91      0.94       237
             48       0.95      0.97      0.96       239
             49       0.97      0.80      0.88       226
             50       0.90      0.95      0.92       240
             51       0.98      0.99      0.99       236
             52       0.96      0.90      0.93       240
             53       0.99      0.96      0.97       235
             54       0.97      0.93      0.95       240
             55       0.99      0.98      0.99       240
             56       0.97      0.92      0.95       239
             57       0.97      0.97      0.97       239
             58       1.00      0.98      0.99       238
             59       0.92      0.98      0.95       240
             60       0.99      0.90      0.94       240
             61       1.00      0.99      0.99       238
             62       0.92      0.95      0.94       239
             63       0.92      0.98      0.95       238
             64       0.98      0.92      0.95       240
             65       0.99      0.92      0.95       239
             66       0.98      0.99      0.99       240
             67       0.95      0.95      0.95       240
             68       0.96      0.98      0.97       239
             69       0.97      0.97      0.97       239
             70       0.98      0.94      0.96       240
             71       0.91      0.96      0.93       239
             72       0.98      0.97      0.97       239
             73       0.99      0.89      0.94       240
             74       0.97      0.99      0.98       237
             75       0.89      0.97      0.92       240
             76       0.97      0.96      0.97       241
             77       0.89      0.91      0.90       240
             78       1.00      0.89      0.94       239
             79       0.90      0.98      0.94       239
             80       0.89      0.96      0.92       240
             81       0.96      0.71      0.81       225
             82       0.95      1.00      0.97       238
             83       0.67      0.96      0.79       239
             84       0.97      0.85      0.91       240
             85       0.95      0.98      0.96       239
             86       0.99      0.93      0.96       240
             87       0.98      0.91      0.94       239
             88       0.97      0.97      0.97       240
             89       0.97      0.94      0.95       239
             90       0.97      0.96      0.96       236
             91       0.91      0.97      0.94       239
             92       0.98      0.95      0.96       240
             93       0.98      0.97      0.98       239
             94       0.98      0.95      0.97       240
             95       0.98      0.99      0.99       239
             96       0.95      0.97      0.96       240
             97       0.98      0.97      0.98       239
             98       0.95      0.98      0.97       237
             99       0.97      0.96      0.97       239
    
    avg / total       0.96      0.95      0.95     23868
    
    [[237   0   0 ...,   0   0   0]
     [  0 228   0 ...,   0   0   0]
     [  0   0 235 ...,   0   0   0]
     ..., 
     [  0   0   0 ..., 233   0   0]
     [  0   0   0 ...,   0 233   0]
     [  0   0   0 ...,   0   0 230]]
    

     更多模型见:http://www.cnblogs.com/bonelee/p/8978060.html

    将上述模型保存并给TensorFlow使用:只需在保存模型前加上 del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:],这样只保留inference时用到的OP(注意:如果之后还需要retrain,请使用执行del之前保存的完整模型),如下:

        # Excerpt replacing the tail of the training script: train as before,
        # save the full model, then drop the training ops and save a second,
        # inference-only copy.
        # NOTE(review): `tf` is not imported in the listing above; presumably
        # this needs `import tensorflow as tf` at the top of the file - confirm.
        model = get_model(width, height, classes=100)
    
        filename = 'cnn_handwrite-acc0.8.tflearn'
        # try to load model and resume training
        #try:
        #    model.load(filename)
        #    print("Model loaded OK. Resume training!")
        #except:
        #    pass
    
        # Initialize our callback with desired accuracy threshold.
        early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.8)
        try:
            model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
                      snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
                      show_metric=True, batch_size=32, callbacks=early_stopping_cb, run_id='cnn_handwrite')
        except StopIteration as e:
            print("OK, stop iterate!Good!")
    
        # First save: the complete model, written before the training ops
        # are removed.
        model.save(filename)
        # Empty the TRAIN_OPS collection in place so that the next save only
        # carries what inference needs (per the surrounding article).
        del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
       
        """
        # print op name
        with tf.Session() as sess:
                init_op = tf.initialize_all_variables()
                sess.run(init_op)
                for v in sess.graph.get_operations():
                    print(v.name)
        """
    
        # Second save: the inference-only copy, under a different file name.
        filename = 'cnn_handwrite-acc0.8.infer.tflearn'
        model.save(filename)
    

     参考:http://www.cnblogs.com/bonelee/p/8445261.html 里的脚本,修改:

    output_node_names = "FullyConnected/Softmax"
    通常为:
    output_node_names = "FullyConnected/Softmax"
    或者
    output_node_names = "FullyConnected_1/Softmax"
    output_node_names = "FullyConnected_2/Softmax"
    具体用哪个名称取决于你使用的全连接层数,上面三个名称分别对应1、2、3层全连接。
    最后,tensorflow里的使用:
    def inference(image):
        """ Classify one image file with the frozen graph in `frozen_model.pb`.

        The image is converted to grayscale, resized to
        `FLAGS.image_size` x `FLAGS.image_size`, scaled by 1/255 and fed to
        the `InputData/X:0` tensor; the probabilities are read from
        `FullyConnected_1/Softmax:0`.

        Arguments:
            image: `str`. Path of the image file to classify.

        Returns:
            A list with one entry per input row: a numpy array holding the
            indices of the 3 most probable classes, most probable first.
        """
        print('inference')
        temp_image = Image.open(image).convert('L')
        # Image.ANTIALIAS was an alias of LANCZOS and is gone in newer Pillow.
        temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size), Image.LANCZOS)
        temp_image = np.asarray(temp_image) / 255.0
        # Reshape with the same size used for the resize above; a hard-coded
        # 32 would silently split a larger image into several "samples".
        temp_image = temp_image.reshape([-1, FLAGS.image_size, FLAGS.image_size, 1])
        # NOTE(review): the training network in this article attaches
        # featurewise zero-center / stdnorm preprocessing to its input layer;
        # presumably that is not part of the frozen graph, so the input may
        # need the same normalization here - confirm.
        with tf.Graph().as_default():
            output_graph_def = tf.GraphDef()
            with open("frozen_model.pb", "rb") as f:
                output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name="")
            with tf.Session() as sess:
                init = tf.global_variables_initializer()
                sess.run(init)
                # To list the available tensors while debugging:
                #   for m in sess.graph.get_operations(): print(m.values())
                op = sess.graph.get_tensor_by_name("FullyConnected_1/Softmax:0")
                input_tensor = sess.graph.get_tensor_by_name('InputData/X:0')
                probs = sess.run(op, feed_dict={input_tensor: temp_image})
                # Function-call form works on both Python 2 and Python 3.
                print(probs)
                # Negate so argsort orders by descending probability.
                return [np.argsort(-word)[:3] for word in probs]
    
    
    def main(_):
        """ Run `inference` on one sample image and log the predicted indices.

        The single positional argument is ignored.
        """
        # image_path = '../data/00010/17724.png'
        image_path = './data/test/00098/104405.png'
        final_predict_val = inference(image_path)
        # 98 is the label logged alongside the prediction for this sample.
        message = 'the result info label {0} predict index {1}'.format(98, final_predict_val)
        logger.info(message)
    
    
    

     一般,TensorFlow里输入节点的名字默认为InputData/X,但这只是op的名字;如果要取tensor,需要在后面加上输出序号0,也就是:InputData/X:0

    同理,FullyConnected_1/Softmax:0。

    最后预测效果:

    [[  8.42533936e-08   1.60850794e-11   2.60133332e-10   2.42555542e-14
        4.96124599e-08   4.45251297e-15   3.98175590e-11   1.64476592e-11
        7.03968351e-13   5.42319011e-12   8.55469237e-11   4.91866422e-13
        1.77282828e-07   4.05237593e-10   3.13049003e-10   1.34780919e-11
        2.05803235e-06   2.87827305e-07   1.47789994e-12   2.53391891e-11
        3.77086790e-13   2.02639586e-10   9.03167027e-13   3.96698889e-11
        1.30850096e-11   5.71980917e-12   3.03487374e-11   2.04132298e-14
        6.25303683e-13   1.46122332e-07   2.17450633e-07   1.69623715e-09
        6.80857757e-12   2.52643609e-13   6.56771096e-11   8.55152287e-16
        1.34496514e-09   1.22644633e-06   1.12011307e-07   7.93476283e-05
        8.24334611e-12   4.77531155e-14   9.39397757e-13   2.38438267e-14
        2.11416329e-10   5.54395712e-08   2.30046147e-12   2.63584043e-10
        4.70621564e-16   5.14432724e-12   6.42602327e-09   1.62485829e-13
        7.39078274e-08   3.19146315e-12   5.25887156e-09   1.35877786e-13
        1.39127886e-13   2.11998293e-13   9.09501097e-09   9.46486750e-07
        2.47498733e-09   2.74523763e-12   1.02716433e-14   1.02069058e-17
        3.09356682e-16   1.51022904e-15   9.34333665e-13   2.62195051e-14
        3.38079781e-16   7.43019903e-13   1.92409091e-13   3.86611994e-13
        2.61276265e-12   1.07969211e-09   1.30814548e-09   2.44038188e-14
        9.79275905e-13   1.41007803e-10   6.15137758e-12   2.08893070e-10
        1.34751668e-14   2.76824767e-15   7.84100464e-16   7.70873335e-15
        5.45704757e-12   3.69386271e-18   2.06012223e-13   1.62567273e-14
        1.54544960e-03   2.05292008e-06   1.31726174e-09   7.04993663e-09
        4.11338266e-03   3.19344110e-07   3.96519717e-05   2.26919351e-12
        2.39114349e-12   2.35558744e-07   9.94213998e-01   1.10125060e-11]]
    the result info label 98 predict index [array([98, 92, 88])]
    
     
     
  • 相关阅读:
    浅谈异或相关性质
    重谈树状数组
    洛谷 U141397 !
    谈谈Sleep和wait的区别
    请描述线程的生命周期
    一个普通main方法的执行,是单线程模式还是多线程模式?为什么?
    创建线程的方式
    一道关于try catch finally返回值的问题
    throw跟throws的区别
    罗列常见的5个非运行时异常
  • 原文地址:https://www.cnblogs.com/bonelee/p/8941654.html
Copyright © 2011-2022 走看看