zoukankan      html  css  js  c++  java
  • Python创建CRNN训练用的LMDB数据库文件

    CRNN简介


    CRNN由 Baoguang Shi, Xiang Bai, Cong Yao提出,2015年7月发表论文:“An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition”,链接地址:https://arxiv.org/abs/1507.05717v1


    CRNN(卷积循环神经网络)集成了卷积神经网络(CNN)和循环神经网络(RNN)的优点。CRNN可以直接从序列标签(例如单词,句子)中学习,不需要详细的单个分别标注,并且对图像序列对象的长度无限定,只需要在训练和测试阶段对图像高度做一下归一化。于现有技术相比,CRNN在场景文本识别上表现良好。

    CRNN中训练数据的格式是LMDB,保存了两种数据,一种是图片数据,一种是标签数据,它们各有其key,如下所示:




    准备CRNN训练数据集


    数据集图片是若干带有文字的图片,文字的高度约占图片高度的80%~90%,数据集标签是txt文本格式,文本内容是图片上的文字,文本名字要跟图片名字一致,如123.jpg对应标签需要是123.txt。


    例如有 01.jpg 和 02.jpg 两个样本,标签文件是 01.txt 和 02.txt :





    创建用于CRNN训练的LMDB数据


    # -*- coding: utf-8 -*-
    import os
    import lmdb # install lmdb by "pip install lmdb"
    import cv2
    import numpy as np
    #from genLineText import GenTextImage
    
    def checkImageIsValid(imageBin):
        if imageBin is None:
            return False
        imageBuf = np.fromstring(imageBin, dtype=np.uint8)
        img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
        if img is None:
            return False
        imgH, imgW = img.shape[0], img.shape[1]
        if imgH * imgW == 0:
            return False
        return True
    
    
    def writeCache(env, cache):
        with env.begin(write=True) as txn:
            for k, v in cache.iteritems():
                txn.put(k, v)
    
    
    def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
        """
        Create LMDB dataset for CRNN training.
    
        ARGS:
            outputPath    : LMDB output path
            imagePathList : list of image path
            labelList     : list of corresponding groundtruth texts
            lexiconList   : (optional) list of lexicon lists
            checkValid    : if true, check the validity of every image
        """
        #print (len(imagePathList) , len(labelList))
        assert(len(imagePathList) == len(labelList))
        nSamples = len(imagePathList)
        print '...................'
        # map_size=1099511627776 定义最大空间是1TB
        env = lmdb.open(outputPath, map_size=1099511627776)
        
        cache = {}
        cnt = 1
        for i in xrange(nSamples):
            imagePath = imagePathList[i]
            label = labelList[i]
            if not os.path.exists(imagePath):
                print('%s does not exist' % imagePath)
                continue
            with open(imagePath, 'r') as f:
                imageBin = f.read()
            if checkValid:
                if not checkImageIsValid(imageBin):
                    print('%s is not a valid image' % imagePath)
                    continue
    
    
            ########## .mdb数据库文件保存了两种数据,一种是图片数据,一种是标签数据,它们各有其key
            imageKey = 'image-%09d' % cnt
            labelKey = 'label-%09d' % cnt
            cache[imageKey] = imageBin
            cache[labelKey] = label
            ##########
            if lexiconList:
                lexiconKey = 'lexicon-%09d' % cnt
                cache[lexiconKey] = ' '.join(lexiconList[i])
            if cnt % 1000 == 0:
                writeCache(env, cache)
                cache = {}
                print('Written %d / %d' % (cnt, nSamples))
            cnt += 1
        nSamples = cnt-1
        cache['num-samples'] = str(nSamples)
        writeCache(env, cache)
        print('Created dataset with %d samples' % nSamples)
    
    
    def read_text(path):
        
        with open(path) as f:
            text = f.read()
        text = text.strip()
        
        return text
    
    
    import glob
    if __name__ == '__main__':
        
        #lmdb 输出目录
        outputPath = '../data/lmdb/trainMy'
    
        # 训练图片路径,标签是txt格式,名字跟图片名字要一致,如123.jpg对应标签需要是123.txt
        path = '../data/dataline/*.jpg'
    
        imagePathList = glob.glob(path)
        print '------------',len(imagePathList),'------------'
        imgLabelLists = []
        for p in imagePathList:
            try:
               imgLabelLists.append((p,read_text(p.replace('.jpg','.txt'))))
            except:
                continue
                
        #imgLabelList = [ (p,read_text(p.replace('.jpg','.txt'))) for p in imagePathList]
        ##sort by lebelList
        imgLabelList = sorted(imgLabelLists,key = lambda x:len(x[1]))
        imgPaths = [ p[0] for p in imgLabelList]
        txtLists = [ p[1] for p in imgLabelList]
        
        createDataset(outputPath, imgPaths, txtLists, lexiconList=None, checkValid=True)



    读取LMDB数据集中图片


    # -*- coding: utf-8 -*-
    import numpy as np
    import lmdb
    import cv2
    
    with lmdb.open("../data/lmdb/train") as env:
        txn = env.begin()
        for key, value in txn.cursor():
            print (key,value)
            imageBuf = np.fromstring(value, dtype=np.uint8)
            img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
            if img is not None:
                cv2.imshow('image', img)
                cv2.waitKey()
            else:
                print 'This is a label: {}'.format(value)

  • 相关阅读:
    python测试开发django-rest-framework-87.分页查询之偏移分页(LimitOffsetPagination)和游标分页(CursorPagination)
    python测试开发django-rest-framework-86.分页查询功能(PageNumberPagination)
    python测试开发django-rest-framework-85.序列化(ModelSerializer)之设置必填(required)和非必填字段
    python测试开发django-rest-framework-84.序列化(ModelSerializer)之日期时间格式带T问题
    去掉DELPHI开启后弹出安全警告框
    使用path 格式获取java hashmap key 值
    Kubeapps-2.0 发布了
    monio系统性能分析相关命令
    imgproxy 强大高效的图片处理服务
    nodejs java 互调用
  • 原文地址:https://www.cnblogs.com/mtcnn/p/9411755.html
Copyright © 2011-2022 走看看