zoukankan      html  css  js  c++  java
  • EnsNet: Ensconce Text in the Wild 模型训练

    参考网址:

    $ https://github.com/HCIILAB/Scene-Text-Removal

    环境配置

    $ git clone https://github.com/HCIILAB/Scene-Text-Removal
    $ https://files.pythonhosted.org/packages/b0/e3/0a7bf93413623ec5a1fa42eb3c89f88731a62155f22ca6b1abc8c67c28d3/mxnet_cu90-1.5.0-py2.py3-none-win_amd64.whl
    $ pip install mxnet_cu90-1.5.0-py2.py3-none-win_amd64.whl
    $ pip install mxnet_cu90-1.5.0-py2.py3-none-win_amd64.whl
    $ 验证是否安装成功
    Python
    >> import mxnet
    import 成功说明没有问题。

    下载数据集

    $ 目的是为了参照它的数据集整理我们自己的数据集。数据格式最好和它的一样,方便先跑通代码。

    模型训练

    $ 目的是为了参照它的数据集整理我们自己的数据集。数据格式最好和它的一样,方便先跑通代码。

    python train.py --trainset_path=’dataset’ --checkpoint=’save_model’ --gpu=0 --lr=0.0002 --n_epoch=5000

    网络调整

    训练这个网络存在的问题是:

    给出来的数据和给的数据读取方式,不匹配,或者至少我没有理解。

    解压以后的数据格式为:

    syn_train下面包含img和label两个文件夹。

    实际读图像的代码如下所示:

    class MyDataSet(Dataset):
    def __init__(self, root, split, is_transform=False,is_train=True):
    self.root = os.path.join(root, split)
    self.is_transform = is_transform
    self.img_paths = []
    self._img_512 = os.path.join(root, split, 'train_512', '{}.png')
    self._mask_512 = os.path.join(root, split, 'mask_512', '{}.png')
    self._lbl_512 = os.path.join(root, split, 'train_512', '{}.png')
    self._img_256 = os.path.join(root, split, 'train_256', '{}.png')
    self._lbl_256 = os.path.join(root, split, 'train_256', '{}.png')
    self._img_128 = os.path.join(root, split, 'train_128', '{}.png')
    for fn in os.listdir(os.path.join(root, split, 'train_512')):
    if len(fn) > 3 and fn[-4:] == '.png':
    self.img_paths.append(fn[:-4])

    def __len__(self):
    return len(self.img_paths)

    def __getitem__(self, idx):
    img_path_512 = self._img_512.format(self.img_paths[idx])
    img_path_256 = self._img_256.format(self.img_paths[idx])
    img_path_128 = self._img_128.format(self.img_paths[idx])
    lbl_path_256 = self._lbl_256.format(self.img_paths[idx])
    mask_path_512 = self._mask_512.format(self.img_paths[idx])
    lbl_path_512 = self._lbl_512.format(self.img_paths[idx])
    img_arr_256 = mx.image.imread(img_path_256).astype(np.float32)/127.5 - 1
    img_arr_512 = mx.image.imread(img_path_512).astype(np.float32)/127.5 - 1
    img_arr_128 = mx.image.imread(img_path_128).astype(np.float32)/127.5 - 1
    img_arr_512 = mx.image.imresize(img_arr_512, img_wd * 2, img_ht)
    img_arr_in_512, img_arr_out_512 = [mx.image.fixed_crop(img_arr_512, 0, 0, img_wd, img_ht),
    mx.image.fixed_crop(img_arr_512, img_wd, 0, img_wd, img_ht)]
    if os.path.exists(mask_path_512):
    mask_512 = mx.image.imread(mask_path_512)
    else:
    mask_512 = mx.image.imread(mask_path_512.replace(".png",'.jpg',1))
    tep_mask_512 = nd.slice_axis(mask_512, axis=2, begin=0, end=1)/255
    if self.is_transform:
    imgs = [img_arr_out_512, img_arr_in_512, tep_mask_512,img_arr_256,img_arr_128]
    imgs = random_horizontal_flip(imgs)
    imgs = random_rotate(imgs)
    img_arr_out_512,img_arr_in_512,tep_mask_512,img_arr_256,img_arr_128 = imgs[0], imgs[1], imgs[2], imgs[3],imgs[4]
    img_arr_in_512, img_arr_out_512 = [nd.transpose(img_arr_in_512, (2,0,1)),
    nd.transpose(img_arr_out_512, (2,0,1))]
    img_arr_out_256 = nd.transpose(img_arr_256, (2,0,1))
    img_arr_out_128 = nd.transpose(img_arr_128, (2,0,1))
    tep_mask_512 = tep_mask_512.reshape(tep_mask_512.shape[0],tep_mask_512.shape[1],1)
    tep_mask_512 = nd.transpose(tep_mask_512,(2,0,1))
    return img_arr_out_512,img_arr_in_512,tep_mask_512,img_arr_out_256,img_arr_out_128
    不匹配,实际我们没有这么多文件夹。
    排查问题的过程:
  • 相关阅读:
    【转载】这才是真正的表扩展方案
    【转载】啥,又要为表增加一列属性?
    【转载】这才是真正的分布式锁
    mysql备份表sql
    selenium定位当前处于那个iframe(frame)中
    MQ手动推送消息
    报表导出时间格式数据多‘0‘
    python里的原始字符串
    qq邮箱设置授权码方法(jenkins)
    Apache与Tomcat有什么关系和区别(转)
  • 原文地址:https://www.cnblogs.com/wjjcjj/p/12017064.html
Copyright © 2011-2022 走看看