  • CRNN PyTorch: training and testing

    1. Repositories

    https://github.com/meijieru/crnn.pytorch
    The original implementation is in Lua: https://github.com/bgshih/crnn
    The required warpctc_pytorch binding comes from: https://github.com/SeanNaren/warp-ctc

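    2. Environment setup

    An ordinary environment should be fine. I used CUDA 10.0, torch 1.2.0, and Python 3.6; other environments should work too.
    Install whatever libraries turn out to be missing: pip install ***

    warp-CTC has to be compiled: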

    git clone https://github.com/SeanNaren/warp-ctc.git
    cd warp-ctc
    mkdir build; cd build
    cmake ..
    make
    cd ../pytorch_binding
    python setup.py install
    

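    This is exactly what I did; if it finishes without errors, it is OK.
    To check that the installation succeeded, start python and run
    import warpctc_pytorch
    If there is no error, the installation succeeded.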
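
The same check can be scripted. This is a minimal sketch, assuming only that the binding installs under the module name `warpctc_pytorch` as in the import above; the helper name `is_importable` is made up for illustration. Note that it only locates the module and does not load the compiled library, so an actual `import` is still the stronger test.

```python
import importlib.util


def is_importable(module_name):
    """Return True if a top-level module can be located on the current Python path."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    name = "warpctc_pytorch"
    status = "found" if is_importable(name) else "NOT found; rebuild the binding"
    print(name, status)
```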

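    3. Data preparation: building the lmdb

    Lay the data out as follows: put the images and the text files in one folder, give each text file the same name as its image, and put the text shown in the image into the text file.
    Run the script https://github.com/wuzuowuyou/crnn_pytorch/blob/master/myfile/create_lmdb.py
    Note that this script needs to be run with Python 2. With Python 3 I got all kinds of encoding errors, while Python 2 ran without a single error. Python 2 also needs lmdb installed (pip2 install lmdb).
    A successful run generates these two files:
    ./lmdb/data.mdb
    ./lmdb/lock.mdb
    Put the lmdb folder under the data directory.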
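
For readers who prefer to stay on Python 3, the folder layout described above can be turned into lmdb entries roughly as follows. This is a sketch, not the author's create_lmdb.py: the function names and the list of image extensions are made up, and the key format (`image-%09d`, `label-%09d`, `num-samples`) is assumed to follow the crnn.pytorch convention. The point is that every key and value is converted to bytes explicitly, which is what Python 3 requires.

```python
import os

IMAGE_EXTS = (".jpg", ".jpeg", ".png")  # assumed image extensions


def collect_pairs(folder):
    """Return sorted (image_path, label_text) pairs for images that have a same-named .txt."""
    pairs = []
    for name in sorted(os.listdir(folder)):
        stem, ext = os.path.splitext(name)
        if ext.lower() not in IMAGE_EXTS:
            continue
        txt_path = os.path.join(folder, stem + ".txt")
        if not os.path.isfile(txt_path):
            continue  # skip images without a label file
        with open(txt_path, encoding="utf-8") as f:
            label = f.read().strip()
        pairs.append((os.path.join(folder, name), label))
    return pairs


def build_cache(pairs):
    """Map lmdb keys to values; both are bytes, indices start at 1."""
    cache = {}
    for i, (img_path, label) in enumerate(pairs, start=1):
        with open(img_path, "rb") as f:
            cache[b"image-%09d" % i] = f.read()
        cache[b"label-%09d" % i] = label.encode("utf-8")
    cache[b"num-samples"] = str(len(pairs)).encode()
    return cache


def write_lmdb(out_dir, cache, map_size=1 << 30):
    """Write the cache into an lmdb environment (creates data.mdb and lock.mdb)."""
    import lmdb  # third-party: pip install lmdb

    os.makedirs(out_dir, exist_ok=True)
    env = lmdb.open(out_dir, map_size=map_size)
    with env.begin(write=True) as txn:
        for key, value in cache.items():
            txn.put(key, value)
    env.close()
```

A call such as `write_lmdb("./data/lmdb", build_cache(collect_pairs("./images")))` would then produce the two files; the paths here are placeholders.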

    4. Training

    python train.py --adadelta --trainRoot ./data/lmdb/ --valRoot ./data/lmdb/ --cuda

    Note: if the labels contain both upper- and lower-case letters, the alphabet has to be changed.
    train.py, line 32:
    parser.add_argument('--alphabet', type=str, default='0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ')
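
To make the link between the alphabet and the network's output size concrete, here is a small sketch. It is a simplified stand-in, not the repository's label converter; `build_char_map` and `encode` are illustrative names. It assumes index 0 is reserved for the CTC blank, which is why the number of classes is the alphabet length plus one.

```python
import string

# Digits plus lower- and upper-case letters: the same 62 characters as the --alphabet default above.
ALPHABET = string.digits + string.ascii_lowercase + string.ascii_uppercase


def build_char_map(alphabet):
    """Map each character to a class index starting at 1; index 0 is left for the CTC blank."""
    return {ch: i + 1 for i, ch in enumerate(alphabet)}


def encode(text, char_map):
    """Turn a label string into a list of class indices; raises KeyError for unknown characters."""
    return [char_map[ch] for ch in text]


char_map = build_char_map(ALPHABET)
nclass = len(ALPHABET) + 1  # 62 characters + 1 blank
```

With this alphabet `nclass` is 63, and a character missing from the alphabet surfaces as a `KeyError` during encoding.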

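    5. Fixing errors

    I ran into all sorts of errors.
    5.1 The capitalization of trainRoot and valRoot needs to be changed so that the names agree.
    5.2 TypeError: Won't implicitly convert Unicode to bytes; use .encode()
    Add encode as the error message suggests:
    txn.get('num-samples'.encode())
    label_byte = txn.get(label_key.encode())
    imgbuf = txn.get(img_key.encode())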
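    5.3
    text, _ = self.encode(text)
    File "/home/crnn.pytorch/utils.py", line 45, in encode
    for char in text
    File "/home/crnn.pytorch/utils.py", line 45, in <listcomp>
    for char in text
    KeyError: 'b'
    Solution: in Python 3, str() applied to a bytes object returns its literal form (b'...') rather than the text, so the label has to be decoded instead.
    In dataset.py, line 61, replace
    label = str(txn.get(label_key))
    with
    label_byte = txn.get(label_key.encode())
    label = label_byte.decode()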
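
The two bytes-related fixes (5.2 and 5.3) can be seen side by side in a small self-contained sketch. `FakeTxn` is a made-up stand-in for an lmdb transaction so the example runs without a database, and the `label-%09d` key format is assumed to match what the dataset code uses.

```python
class FakeTxn:
    """Stand-in for an lmdb transaction: accepts bytes keys only, returns bytes values."""

    def __init__(self, data):
        self._data = data

    def get(self, key):
        if not isinstance(key, bytes):
            raise TypeError("Won't implicitly convert Unicode to bytes; use .encode()")
        return self._data.get(key)


def read_label(txn, index):
    """Encode the key on the way in, decode the value on the way out."""
    key = "label-%09d" % index
    raw = txn.get(key.encode())
    return raw.decode("utf-8")


txn = FakeTxn({b"label-000000001": b"abc"})
wrong = str(b"abc")        # the literal form "b'abc'", including the b and the quotes
right = read_label(txn, 1)  # the actual text "abc"
```

The `wrong` value shows where stray characters in the label come from; `right` is what the encoder expects.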

    5.4 raise ValueError('sampler option is mutually exclusive with '
    ValueError: sampler option is mutually exclusive with shuffle
    In short, sampler and shuffle cannot be used together.
    I added "and 0" so that the sampler is not used:
    if not opt.random_sample and 0:
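
An alternative to disabling the sampler is to pass `shuffle=True` only when no sampler is given. The helper below is a hypothetical sketch (the name `loader_kwargs` is not from the repository); it just builds the keyword arguments, so it runs without PyTorch.

```python
def loader_kwargs(sampler=None):
    """Keyword arguments for a DataLoader that never combine a sampler with shuffle=True."""
    if sampler is not None:
        # A custom sampler decides the order itself, so shuffle must stay off.
        return {"sampler": sampler, "shuffle": False}
    return {"sampler": None, "shuffle": True}
```

In train.py this could be used as `torch.utils.data.DataLoader(train_dataset, batch_size=opt.batchSize, **loader_kwargs(sampler))`, assuming `sampler` is either `None` or a sampler instance at that point.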

    5.5 Validation raised another error:
    Start val
    Traceback (most recent call last):
    File "/data_2/project_2021/crnn/crnn.pytorch-master/train.py", line 219, in <module>
    val(crnn, test_dataset, criterion)
    File "/data_2/project_2021/crnn/crnn.pytorch-master/train.py", line 168, in val
    preds = preds.squeeze(2)
    IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
    I skipped validation by adding "and 0":
    if i % opt.valInterval == 0 and 0:
        val(crnn, test_dataset, criterion)
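
The message says that `preds` has only two dimensions when `squeeze(2)` runs. A plausible explanation, not verified against this exact code, is that a preceding reduction such as `preds.max(2)` already drops the reduced dimension in recent PyTorch versions, so no axis 2 is left to squeeze. The pure-Python sketch below only tracks shapes to illustrate this; the function names and tensor sizes are hypothetical.

```python
def reduce_dim(shape, dim, keepdim=False):
    """Shape after a reduction (such as max) over `dim`."""
    shape = list(shape)
    if keepdim:
        shape[dim] = 1
    else:
        del shape[dim]
    return tuple(shape)


def squeeze_dim(shape, dim):
    """Shape after squeeze(dim); raises IndexError when `dim` is not a valid axis."""
    ndim = len(shape)
    if not -ndim <= dim < ndim:
        raise IndexError(
            "Dimension out of range (expected to be in range of [%d, %d], but got %d)"
            % (-ndim, ndim - 1, dim)
        )
    shape = list(shape)
    if shape[dim] == 1:
        del shape[dim]
    return tuple(shape)


# Hypothetical sizes: (sequence length, batch, classes)
logits_shape = (26, 64, 63)
after_max = reduce_dim(logits_shape, 2)  # the class axis is gone: (26, 64)
```

Calling `squeeze_dim(after_max, 2)` reproduces the same kind of IndexError. If this is indeed the cause, removing the `preds = preds.squeeze(2)` line may be enough to keep validation enabled instead of switching it off.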

    With the errors fixed, training can start. The output looks like this:

      (relu6): ReLU(inplace=True)
      )
      (rnn): Sequential(
        (0): BidirectionalLSTM(
          (rnn): LSTM(512, 256, bidirectional=True)
          (embedding): Linear(in_features=512, out_features=256, bias=True)
        )
        (1): BidirectionalLSTM(
          (rnn): LSTM(256, 256, bidirectional=True)
          (embedding): Linear(in_features=512, out_features=63, bias=True)
        )
      )
    )
    [0/100000000][1/9] Loss: 8.430408
    [0/100000000][2/9] Loss: 20.137066
    [0/100000000][3/9] Loss: 25.239346
    [0/100000000][4/9] Loss: 21.249365
    [0/100000000][5/9] Loss: 20.604660
    [0/100000000][6/9] Loss: 14.782236
    

    6. Testing with demo.py

    This line has to be changed to match what was used for training:
    model = crnn.CRNN(32, 1, 37, 256)

    Error:
    File "/data_2/project_2021/crnn/crnn.pytorch-master/demo_show.py", line 42, in <module>
    model.load_state_dict(torch.load(model_path))
    File "/data_1/Yang/software_install/Anaconda1105/envs/CenterNet_1.0_3.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 845, in load_state_dict
    self.__class__.__name__, " ".join(error_msgs)))
    RuntimeError: Error(s) in loading state_dict for CRNN:
    Missing key(s) in state_dict: "cnn.conv0.weight", "cnn.conv0.bias", "cnn.conv1.weight", "cnn.conv1.bias", "cnn.conv2.weight", "cnn.conv2.bias", "cnn.batchnorm2.weight", "cnn.batchnorm2.bias", "cnn.batchnorm2.running_mean", "cnn.batchnorm2.running_var", "cnn.conv3.weight", "cnn.conv3.bias", "cnn.conv4.weight", "cnn.conv4.bias", "cnn.batchnorm4.weight", "cnn.batchnorm4.bias", "cnn.batchnorm4.running_mean", "cnn.batchnorm4.running_var", "cnn.conv5.weight", "cnn.conv5.bias", "cnn.conv6.weight", "cnn.conv6.bias", "cnn.batchnorm6.weight", "cnn.batchnorm6.bias", "cnn.batchnorm6.running_mean", "cnn.batchnorm6.running_var", "rnn.0.rnn.weight_ih_l0", "rnn.0.rnn.weight_hh_l0", "rnn.0.rnn.bias_ih_l0", "rnn.0.rnn.bias_hh_l0", "rnn.0.rnn.weight_ih_l0_reverse", "rnn.0.rnn.weight_hh_l0_reverse", "rnn.0.rnn.bias_ih_l0_reverse", "rnn.0.rnn.bias_hh_l0_reverse", "rnn.0.embedding.weight", "rnn.0.embedding.bias", "rnn.1.rnn.weight_ih_l0", "rnn.1.rnn.weight_hh_l0", "rnn.1.rnn.bias_ih_l0", "rnn.1.rnn.bias_hh_l0", "rnn.1.rnn.weight_ih_l0_reverse", "rnn.1.rnn.weight_hh_l0_reverse", "rnn.1.rnn.bias_ih_l0_reverse", "rnn.1.rnn.bias_hh_l0_reverse", "rnn.1.embedding.weight", "rnn.1.embedding.bias".
    Unexpected key(s) in state_dict: "module.cnn.conv0.weight", "module.cnn.conv0.bias", "module.cnn.conv1.weight", "module.cnn.conv1.bias", "module.cnn.conv2.weight", "module.cnn.conv2.bias", "module.cnn.batchnorm2.weight", "module.cnn.batchnorm2.bias", "module.cnn.batchnorm2.running_mean", "module.cnn.batchnorm2.running_var", "module.cnn.batchnorm2.num_batches_tracked", "module.cnn.conv3.weight", "module.cnn.conv3.bias", "module.cnn.conv4.weight", "module.cnn.conv4.bias", "module.cnn.batchnorm4.weight", "module.cnn.batchnorm4.bias", "module.cnn.batchnorm4.running_mean", "module.cnn.batchnorm4.running_var", "module.cnn.batchnorm4.num_batches_tracked", "module.cnn.conv5.weight", "module.cnn.conv5.bias", "module.cnn.conv6.weight", "module.cnn.conv6.bias", "module.cnn.batchnorm6.weight", "module.cnn.batchnorm6.bias", "module.cnn.batchnorm6.running_mean", "module.cnn.batchnorm6.running_var", "module.cnn.batchnorm6.num_batches_tracked", "module.rnn.0.rnn.weight_ih_l0", "module.rnn.0.rnn.weight_hh_l0", "module.rnn.0.rnn.bias_ih_l0", "module.rnn.0.rnn.bias_hh_l0", "module.rnn.0.rnn.weight_ih_l0_reverse", "module.rnn.0.rnn.weight_hh_l0_reverse", "module.rnn.0.rnn.bias_ih_l0_reverse", "module.rnn.0.rnn.bias_hh_l0_reverse", "module.rnn.0.embedding.weight", "module.rnn.0.embedding.bias", "module.rnn.1.rnn.weight_ih_l0", "module.rnn.1.rnn.weight_hh_l0", "module.rnn.1.rnn.bias_ih_l0", "module.rnn.1.rnn.bias_hh_l0", "module.rnn.1.rnn.weight_ih_l0_reverse", "module.rnn.1.rnn.weight_hh_l0_reverse", "module.rnn.1.rnn.bias_ih_l0_reverse", "module.rnn.1.rnn.bias_hh_l0_reverse", "module.rnn.1.embedding.weight", "module.rnn.1.embedding.bias".

    Process finished with exit code 1

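    The cause is that the parameter names in the saved .pth weights carry an extra "module." prefix (this typically happens when the model was wrapped in torch.nn.DataParallel at save time); removing the prefix fixes it.
    Change the code as follows: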

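    import collections

    import torch

    # alphabet and model_path are assumed to be defined earlier in demo.py
    nclass = len(alphabet) + 1  # one extra class for the CTC blank

    model = crnn.CRNN(32, 1, nclass, 256)  # was: crnn.CRNN(32, 1, 37, 256)
    if torch.cuda.is_available():
        model = model.cuda()

    # Debug aid: list the parameter names the model expects
    # for m in model.state_dict().keys():
    #     print("==:: ", m)

    load_model_ = torch.load(model_path)
    # Debug aid: list the names and shapes stored in the checkpoint
    # for k, v in load_model_.items():
    #     print(k, "  ::shape", v.shape)

    # Every checkpoint key starts with `module.` (7 characters); drop that prefix
    state_dict_rename = collections.OrderedDict()
    for k, v in load_model_.items():
        name = k[7:]  # remove `module.`
        state_dict_rename[name] = v

    print('loading pretrained model from %s' % model_path)
    model.load_state_dict(state_dict_rename)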
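
Slicing off the first seven characters works only if every key really starts with `module.`. A slightly more defensive variant, sketched here with an illustrative helper name (`strip_prefix` is not part of the repository), removes the prefix only where it is present, so the same code also accepts checkpoints saved without it.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with `prefix` removed from the keys that start with it."""
    renamed = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        renamed[new_key] = value
    return renamed
```

Usage would be `model.load_state_dict(strip_prefix(torch.load(model_path)))`.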
    

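    After that, testing works.
    There were too many changes to list, so I uploaded the fixed code to git for anyone who needs it. It includes 10 test images with labels, enough to run the lmdb conversion.
    https://github.com/wuzuowuyou/crnn_pytorch

    A good memory is no match for a well-worn keyboard: little by little, accumulate and improve!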
  • Original post: https://www.cnblogs.com/yanghailin/p/14519525.html