zoukankan      html  css  js  c++  java
  • TensorFlow创建CNN网络进行中文手写文字识别

    数据集下载地址:http://www.nlpr.ia.ac.cn/databases/handwriting/download.html

    chinese_write_detection.py

    # -*- coding: utf-8 -*-
    import tensorflow as tf
    import os
    import random
    import tensorflow.contrib.slim as slim
    import time
    import numpy as np
    import pickle
    from PIL import Image
    from log_utils import get_logger
    
    # Module-level logger used by the functions below.
    logger = get_logger("HandWritten  Practice")
    # NOTE(review): root_path is not referenced by any code visible in this file.
    root_path = 'D:/eclipse-workspace/sxzsb'

    # Augmentation switches, read by DataIterator.data_augmentation.
    tf.app.flags.DEFINE_boolean('random_flip_up_down', False, "Whether to random flip up down")
    tf.app.flags.DEFINE_boolean('random_brightness', True, "whether to adjust brightness")
    tf.app.flags.DEFINE_boolean('random_contrast', True, "whether to random constrast")

    # Model / training-loop settings.
    tf.app.flags.DEFINE_integer('charset_size', 3755, "Choose the first `charset_size` character to conduct our experiment.")
    tf.app.flags.DEFINE_integer('image_size', 64, "Needs to provide same value as in training.")
    tf.app.flags.DEFINE_boolean('gray', True, "whether to change the rbg to gray")
    tf.app.flags.DEFINE_integer('max_steps', 12002, 'the max training steps ')
    tf.app.flags.DEFINE_integer('eval_steps', 50, "the step num to eval")
    tf.app.flags.DEFINE_integer('save_steps', 2000, "the steps to save")

    # Filesystem locations.
    # NOTE(review): train()/validation() below use hard-coded '../data/...' paths
    # rather than train_data_dir / test_data_dir.
    tf.app.flags.DEFINE_string('checkpoint_dir', 'D:/eclipse-workspace/sxzsb/checkpoint', 'the checkpoint dir')
    tf.app.flags.DEFINE_string('train_data_dir', 'D:/eclipse-workspace/sxzsb/data/train', 'the train dataset dir(containing png files)')
    tf.app.flags.DEFINE_string('test_data_dir', 'D:/eclipse-workspace/sxzsb/data/test', 'the test dataset dir(containing png files)')
    tf.app.flags.DEFINE_string('log_dir', 'D:/eclipse-workspace/sxzsb/log', 'the logging path)')

    # Run control.
    tf.app.flags.DEFINE_boolean('restore', False, 'whether to restore from checkpoint')
    tf.app.flags.DEFINE_integer('epoch', 1, 'Number of epoches')
    # The same batch size is used by train() and validation().
    tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size for training and validation')
    # The help text lists exactly the modes that main() dispatches on.
    tf.app.flags.DEFINE_string('mode', 'train', 'Running mode. One of {"train", "validation", "inference"}')
    FLAGS = tf.app.flags.FLAGS
    
    
    class DataIterator:
        """Index a directory of labelled PNG images and feed them through TF input queues.

        The data directory is expected to hold one sub-directory per class whose
        name is the integer class id (for example ``00020``); every file below
        such a sub-directory is treated as one sample of that class.
        """

        def __init__(self, data_dir):
            """Collect and shuffle the samples whose class id is below FLAGS.charset_size.

            Args:
                data_dir: root directory of the data set, with or without a
                    trailing path separator.
            """
            # Set FLAGS.charset_size to a small value if available computation power is limited.
            samples = self._scan(data_dir, FLAGS.charset_size)
            random.shuffle(samples)
            self.image_names = [path for path, _ in samples]
            self.labels = [label for _, label in samples]  # int("00020") -> 20

        @staticmethod
        def _scan(data_dir, charset_size):
            """Return ``(file_path, label)`` pairs for every usable sample under data_dir.

            A file is usable when the first directory below ``data_dir`` on its
            path parses as an integer in ``[0, charset_size)``. Files lying
            directly in ``data_dir`` and files under non-numeric directories are
            skipped instead of breaking the label parsing.
            """
            samples = []
            for root, _, file_list in os.walk(data_dir):
                # The class id is the first path component relative to data_dir;
                # for data_dir itself relpath yields '.', which is not a number.
                class_dir = os.path.relpath(root, data_dir).split(os.sep)[0]
                try:
                    label = int(class_dir)
                except ValueError:
                    continue
                if not 0 <= label < charset_size:
                    continue
                samples.extend((os.path.join(root, file_name), label) for file_name in file_list)
            return samples

        @property
        def size(self):
            """Number of indexed samples (read-only)."""
            return len(self.labels)

        @staticmethod
        def data_augmentation(images):
            """Apply the random image transformations enabled through FLAGS.

            Args:
                images: image tensor to augment.

            Returns:
                The tensor after the enabled flip / brightness / contrast ops.
            """
            if FLAGS.random_flip_up_down:
                images = tf.image.random_flip_up_down(images)
            if FLAGS.random_brightness:
                images = tf.image.random_brightness(images, max_delta=0.3)
            if FLAGS.random_contrast:
                images = tf.image.random_contrast(images, 0.8, 1.2)
            return images

        def input_pipeline(self, batch_size, num_epochs=None, aug=False):
            """Build the queue-based input ops that yield shuffled batches.

            Args:
                batch_size: number of samples per batch.
                num_epochs: forwarded to ``tf.train.slice_input_producer``.
                aug: when True, run ``data_augmentation`` on each decoded image.

            Returns:
                ``(image_batch, label_batch)`` tensors produced by
                ``tf.train.shuffle_batch``.
            """
            # 1. Convert the file names and labels to tensors and build the input queue.
            images_tensor = tf.convert_to_tensor(self.image_names, dtype=tf.string)
            labels_tensor = tf.convert_to_tensor(self.labels, dtype=tf.int64)
            input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor], num_epochs=num_epochs)
            # 2. Read one (file name, label) pair from the queue and decode the image.
            labels = input_queue[1]
            images_content = tf.read_file(input_queue[0])
            images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)
            if aug:
                images = self.data_augmentation(images)
            new_size = tf.constant([FLAGS.image_size, FLAGS.image_size], dtype=tf.int32)
            images = tf.image.resize_images(images, new_size)
            # 3. Collect the single examples into shuffled batches.
            image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size, capacity=50000,
                                                              min_after_dequeue=10000)
            return image_batch, label_batch
    
    
    def build_graph(top_k):
        """Build the CNN classifier graph together with its loss, metrics and training op.

        The network is three conv + max-pool stages followed by two fully
        connected layers, with dropout in front of each fully connected layer.

        Args:
            top_k: how many highest-probability classes to report.

        Returns:
            A dict holding the input placeholders ('images', 'labels',
            'keep_prob'), the given 'top_k', and the ops/tensors 'global_step',
            'train_op', 'loss', 'accuracy', 'accuracy_top_k',
            'merged_summary_op', 'predicted_distribution',
            'predicted_index_top_k' and 'predicted_val_top_k'.
        """
        # with tf.device('/cpu:0'):
        keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')
        # The spatial size follows FLAGS.image_size so the placeholder always
        # matches what DataIterator.input_pipeline resizes images to.
        images = tf.placeholder(dtype=tf.float32, shape=[None, FLAGS.image_size, FLAGS.image_size, 1],
                                name='image_batch')
        labels = tf.placeholder(dtype=tf.int64, shape=[None], name='label_batch')

        # slim.conv2d(inputs, num_outputs [number of kernels], kernel_size [height, width], stride, padding)
        conv_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv1')
        max_pool_1 = slim.max_pool2d(conv_1, [2, 2], [2, 2], padding='SAME')
        conv_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv2')
        max_pool_2 = slim.max_pool2d(conv_2, [2, 2], [2, 2], padding='SAME')
        conv_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3')
        max_pool_3 = slim.max_pool2d(conv_3, [2, 2], [2, 2], padding='SAME')

        flatten = slim.flatten(max_pool_3)
        fc1 = slim.fully_connected(tf.nn.dropout(flatten, keep_prob), 1024, activation_fn=tf.nn.tanh, scope='fc1')
        logits = slim.fully_connected(tf.nn.dropout(fc1, keep_prob), FLAGS.charset_size, activation_fn=None, scope='fc2')
        # Softmax and cross-entropy are computed together in a single op on the raw logits.
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
        accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))

        # NOTE(review): the step counter has no explicit dtype and is initialised
        # from 0.0, so it is presumably a float variable - confirm before changing
        # it, since existing checkpoints store it.
        global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)
        rate = tf.train.exponential_decay(2e-4, global_step, decay_steps=2000, decay_rate=0.97, staircase=True)
        # minimize() is given global_step, so the counter is tied to this training op.
        train_op = tf.train.AdamOptimizer(learning_rate=rate).minimize(loss, global_step=global_step)
        # Plain softmax output, used for reporting predictions.
        probabilities = tf.nn.softmax(logits)

        tf.summary.scalar('loss', loss)
        tf.summary.scalar('accuracy', accuracy)
        merged_summary_op = tf.summary.merge_all()
        predicted_val_top_k, predicted_index_top_k = tf.nn.top_k(probabilities, k=top_k)
        # Fraction of samples whose true label is among the top_k predictions.
        accuracy_in_top_k = tf.reduce_mean(tf.cast(tf.nn.in_top_k(probabilities, labels, top_k), tf.float32))

        # Return the placeholders and operators by name.
        return {'images': images,
                'labels': labels,
                'keep_prob': keep_prob,
                'top_k': top_k,
                'global_step': global_step,
                'train_op': train_op,
                'loss': loss,
                'accuracy': accuracy,
                'accuracy_top_k': accuracy_in_top_k,
                'merged_summary_op': merged_summary_op,
                'predicted_distribution': probabilities,
                'predicted_index_top_k': predicted_index_top_k,
                'predicted_val_top_k': predicted_val_top_k}
    
    
    def train():
        """Train the network, periodically evaluating one test batch and saving checkpoints.

        Training runs until the global step exceeds FLAGS.max_steps or the input
        queue raises OutOfRangeError. Scalar summaries are written below
        FLAGS.log_dir and checkpoints below FLAGS.checkpoint_dir.
        """
        print('Begin training')
        # NOTE(review): the data directories are hard-coded here;
        # FLAGS.train_data_dir / FLAGS.test_data_dir are not consulted.
        train_feeder = DataIterator(data_dir='../data/train/')
        test_feeder = DataIterator(data_dir='../data/test/')
        checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'my-model')
        # Create the checkpoint directory up front so saving cannot fail on a
        # missing parent directory after hours of training.
        if not os.path.isdir(FLAGS.checkpoint_dir):
            os.makedirs(FLAGS.checkpoint_dir)
        with tf.Session() as sess:
            # The input pipelines must be built before the queue runners are started.
            train_images, train_labels = train_feeder.input_pipeline(batch_size=FLAGS.batch_size, aug=True)
            test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size)
            graph = build_graph(top_k=1)
            sess.run(tf.global_variables_initializer())
            # Start the queue threads that fill the input pipelines.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            saver = tf.train.Saver()

            train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
            test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/val')
            if FLAGS.restore:
                # Continue from the latest saved model; the step counter is part
                # of the checkpoint (see global_step=... in the save calls).
                ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if ckpt:
                    saver.restore(sess, ckpt)
                    print("restore from the checkpoint {0}".format(ckpt))

            logger.info(':::Training Start:::')
            try:
                while not coord.should_stop():
                    start_time = time.time()
                    train_images_batch, train_labels_batch = sess.run([train_images, train_labels])
                    feed_dict = {graph['images']: train_images_batch,
                                 graph['labels']: train_labels_batch,
                                 graph['keep_prob']: 0.8}  # keep 80% of the connections
                    _, loss_val, train_summary, step = sess.run(
                        [graph['train_op'], graph['loss'], graph['merged_summary_op'], graph['global_step']],
                        feed_dict=feed_dict)
                    train_writer.add_summary(train_summary, step)
                    end_time = time.time()
                    logger.info("the step {0} takes {1} loss {2}".format(step, end_time - start_time, loss_val))
                    if step > FLAGS.max_steps:
                        break
                    if step % FLAGS.eval_steps == 1:
                        # Evaluate on a single test batch with dropout disabled.
                        test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                        feed_dict = {graph['images']: test_images_batch,
                                     graph['labels']: test_labels_batch,
                                     graph['keep_prob']: 1.0}
                        accuracy_test, test_summary = sess.run(
                            [graph['accuracy'], graph['merged_summary_op']],
                            feed_dict=feed_dict)
                        test_writer.add_summary(test_summary, step)
                        logger.info('===============Eval a batch=======================')
                        logger.info('the step {0} test accuracy: {1}'
                                    .format(step, accuracy_test))
                        logger.info('===============Eval a batch=======================')
                    if step % FLAGS.save_steps == 1:
                        logger.info('Save the ckpt of {0}'.format(step))
                        saver.save(sess, checkpoint_prefix, global_step=graph['global_step'])
            except tf.errors.OutOfRangeError:
                logger.info('==================Train Finished================')
                saver.save(sess, checkpoint_prefix, global_step=graph['global_step'])
            finally:
                # Ask every queue thread to stop, then close the summary writers
                # so buffered events are written out and their files released.
                coord.request_stop()
                train_writer.close()
                test_writer.close()
            coord.join(threads)
    
    
    def validation():
        """Run the latest checkpoint once over the test set and collect top-k predictions.

        Returns:
            A dict with 'prob' (top-k probabilities per sample), 'indices'
            (top-k class indices per sample) and 'groundtruth' (true labels),
            each a list aligned by sample.
        """
        print('validation')
        test_feeder = DataIterator(data_dir='../data/test/')

        final_predict_val = []
        final_predict_index = []
        groundtruth = []

        with tf.Session() as sess:
            test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size, num_epochs=1)
            graph = build_graph(3)

            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())  # initialize test_feeder's inside state

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            saver = tf.train.Saver()
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt is not None:
                saver.restore(sess, ckpt)
                print("restore from the checkpoint {0}".format(ckpt))
            else:
                logger.warning('no checkpoint found in {0}; validating freshly initialised weights'
                               .format(FLAGS.checkpoint_dir))

            logger.info(':::Start validation:::')
            batch_no = 0
            # Numbers of correctly classified samples, accumulated over batches.
            correct_top_1, correct_top_k = 0.0, 0.0
            try:
                while not coord.should_stop():
                    start_time = time.time()
                    test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                    batch_no += 1
                    feed_dict = {graph['images']: test_images_batch,
                                 graph['labels']: test_labels_batch,
                                 graph['keep_prob']: 1.0}
                    batch_labels, probs, indices, acc_1, acc_k = sess.run([graph['labels'],
                                                                           graph['predicted_val_top_k'],
                                                                           graph['predicted_index_top_k'],
                                                                           graph['accuracy'],
                                                                           graph['accuracy_top_k']], feed_dict=feed_dict)
                    final_predict_val += probs.tolist()
                    final_predict_index += indices.tolist()
                    groundtruth += batch_labels.tolist()
                    # Weight each batch accuracy by the batch's real length.
                    correct_top_1 += acc_1 * len(batch_labels)
                    correct_top_k += acc_k * len(batch_labels)
                    end_time = time.time()
                    logger.info("the batch {0} takes {1} seconds, accuracy = {2}(top_1) {3}(top_k)"
                                .format(batch_no, end_time - start_time, acc_1, acc_k))

            except tf.errors.OutOfRangeError:
                logger.info('==================Validation Finished================')
            finally:
                coord.request_stop()
            coord.join(threads)

            # Divide by the number of samples actually evaluated rather than the
            # data set size: shuffle_batch is called without
            # allow_smaller_final_batch, so a trailing partial batch is never
            # delivered and would otherwise be counted as misclassified.
            evaluated = len(groundtruth)
            if evaluated:
                acc_top_1 = correct_top_1 / evaluated
                acc_top_k = correct_top_k / evaluated
                logger.info('top 1 accuracy {0} top k accuracy {1}'.format(acc_top_1, acc_top_k))
            else:
                logger.warning('no validation batch was evaluated')
        return {'prob': final_predict_val, 'indices': final_predict_index, 'groundtruth': groundtruth}
    
    
    def inference(image):
        """Classify a single image file with the latest checkpoint.

        Args:
            image: path (or file object) accepted by ``PIL.Image.open``.

        Returns:
            ``(predict_val, predict_index)``: the top-3 probabilities and the
            matching class indices, as returned by ``sess.run``.
        """
        print('inference')
        temp_image = Image.open(image).convert('L')
        # Image.LANCZOS is the same filter as the old Image.ANTIALIAS alias,
        # which newer Pillow releases no longer provide.
        temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size), Image.LANCZOS)
        temp_image = np.asarray(temp_image) / 255.0
        # Use the configured size rather than a literal 64 so the reshape
        # always agrees with the resize above.
        temp_image = temp_image.reshape([-1, FLAGS.image_size, FLAGS.image_size, 1])
        with tf.Session() as sess:
            logger.info('========start inference============')
            # No label is fed: the fetched prediction ops do not depend on it.
            graph = build_graph(top_k=3)
            saver = tf.train.Saver()
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
            else:
                logger.warning('no checkpoint found in {0}; model variables have not been restored'
                               .format(FLAGS.checkpoint_dir))
            predict_val, predict_index = sess.run([graph['predicted_val_top_k'], graph['predicted_index_top_k']],
                                                  feed_dict={graph['images']: temp_image, graph['keep_prob']: 1.0})
        return predict_val, predict_index
    
    
    def main(_):
        """Dispatch to train / validation / inference according to FLAGS.mode.

        Args:
            _: positional argument supplied by ``tf.app.run`` (unused).

        Raises:
            ValueError: if FLAGS.mode is not one of the supported modes.
        """
        print(FLAGS.mode)
        if FLAGS.mode == "train":
            train()
        elif FLAGS.mode == 'validation':
            # validation() returns the predictions and ground truth as a dict.
            dct = validation()
            result_file = 'result.dict'
            logger.info('Write result into {0}'.format(result_file))
            with open(result_file, 'wb') as f:
                pickle.dump(dct, f)
            logger.info('Write file ends')
        elif FLAGS.mode == 'inference':
            image_path = '../data/test/00159/75700.png'
            # The true class id is the name of the directory holding the image,
            # so derive it from the path instead of logging an unrelated constant.
            true_label = int(os.path.basename(os.path.dirname(image_path)))
            final_predict_val, final_predict_index = inference(image_path)
            logger.info('the result info label {0} predict index {1} predict_val {2}'.format(true_label,
                                                                                             final_predict_index,
                                                                                             final_predict_val))
        else:
            # An unknown mode used to fall through silently and do nothing.
            raise ValueError('unknown mode {0!r}; expected "train", "validation" or "inference"'
                             .format(FLAGS.mode))
    
    
    # Script entry point: only runs when the file is executed directly.
    if __name__ == "__main__":
        tf.app.run()  # It's just a very quick wrapper that handles flag parsing and then dispatches to your own main.

    log_utils.py

    # -*- coding:utf-8 -*-
    import os, os.path as osp
    import time
    
    
    def strftime(t=None):
        """Format a timestamp as ``YYYYmmdd-HHMMSS`` in local time.

        Args:
            t: seconds since the epoch, or None for the current time.

        Returns:
            The formatted timestamp string.
        """
        # time.localtime(None) already means "now"; passing t straight through
        # keeps t=0 (the epoch itself) from being mistaken for "no value".
        return time.strftime("%Y%m%d-%H%M%S", time.localtime(t))
    
    
    #################
    # Logging
    #################
    import logging
    from logging.handlers import TimedRotatingFileHandler
    # Configure the root logger's message format at import time.
    logging.basicConfig(format="[ %(asctime)s][%(module)s.%(funcName)s] %(message)s")
    
    # Level applied by get_logger() when the caller passes none.
    DEFAULT_LEVEL = logging.INFO
    # Directory in which init_fh() creates the log file; None disables file logging.
    DEFAULT_LOGGING_DIR = osp.join("logs", "gcforest")
    # Shared file handler, created on first use by init_fh().
    fh = None
    
    
    def init_fh():
        """Create the shared log file handler once.

        Does nothing when the handler already exists or when
        DEFAULT_LOGGING_DIR is None. Otherwise the logging directory is created
        if missing and ``fh`` is bound to a FileHandler writing to a
        timestamp-named ``.log`` file inside it.
        """
        global fh
        if fh is not None or DEFAULT_LOGGING_DIR is None:
            return
        if not osp.exists(DEFAULT_LOGGING_DIR):
            os.makedirs(DEFAULT_LOGGING_DIR)
        log_file = osp.join(DEFAULT_LOGGING_DIR, strftime() + ".log")
        fh = logging.FileHandler(log_file)
        line_format = logging.Formatter("[ %(asctime)s][%(module)s.%(funcName)s] %(message)s")
        fh.setFormatter(line_format)
    
    
    def update_default_level(defalut_level):
        """Replace the module-wide DEFAULT_LEVEL used by later get_logger() calls.

        Args:
            defalut_level: the new default logging level. The misspelt
                parameter name is kept as-is so keyword callers keep working.
        """
        global DEFAULT_LEVEL
        DEFAULT_LEVEL = defalut_level
    
    
    def update_default_logging_dir(default_logging_dir):
        """Replace the module-wide DEFAULT_LOGGING_DIR read by init_fh().

        Args:
            default_logging_dir: the new log directory; init_fh() skips file
                logging when this is None.
        """
        global DEFAULT_LOGGING_DIR
        DEFAULT_LOGGING_DIR = default_logging_dir
    
    
    def get_logger(name="HandWrittenPractice", level=None):
        """Return the named logger, set to ``level`` and attached to the shared file handler.

        Args:
            name: logger name passed to ``logging.getLogger``.
            level: logging level to set; None selects DEFAULT_LEVEL.

        Returns:
            The configured ``logging.Logger``.
        """
        # Compare against None explicitly: logging.NOTSET is 0 and would be
        # silently replaced by the default under a plain truthiness test.
        if level is None:
            level = DEFAULT_LEVEL
        logger = logging.getLogger(name)
        logger.setLevel(level)
        init_fh()
        if fh is not None:
            logger.addHandler(fh)
        return logger

    Train

    python chinese_write_detection.py --mode=train --max_steps=200000 --eval_steps=1000 --save_steps=10000

    Validation

    python chinese_write_detection.py --mode=validation

    Inference

    python chinese_write_detection.py --mode=inference
  • 相关阅读:
    Javascript实现图片的预加载的完整实现
    python模块查找机制探究
    网络协议模拟之QQ微博分享接口应用
    每周一荐:差异利器Beyond Compare
    Asp.Net MVC4 入门介绍
    单元测试一例:学习断言、测试用例函数的编写
    开源一个网络框架
    .NET服务端持续输出信息到客户端
    Python服务器改造
    CYQ.Data 数据框架 V4.0
  • 原文地址:https://www.cnblogs.com/gmhappy/p/9472395.html
Copyright © 2011-2022 走看看