zoukankan      html  css  js  c++  java
  • CTPN-自然文本场景检测代码阅读笔记

    TensorFlow代码 https://github.com/eragonruan/text-detection-ctpn

    CTPN网络结构理解:

    知乎链接:https://zhuanlan.zhihu.com/p/34757009

    训练 main/train.py

    1. utils/prepare/split_label.py

    • 缩放图片resize image(长宽 最大1200,最小600)
    • label处理  将大矩形框label划分一个个16*16的小矩形

    2. 输入

    • input_image 原图像 [[1, H, W, 3]]
    • bbox(GT) [[x_min, y_min, x_max, y_max, 1], […], …]
    • im_info(GT) 图像的高,宽,通道(二维ndarray) [[h,w,c]]

    3. 模型 model_train.py -> model()

    • 图像去均值: mean_image_subtraction(均值设为means=[123.68, 116.78, 103.94])
    • 目的:图像标准化,移除共同部分,凸显个体差异。
    • 输入到 VGG16, conv5 -> [N, H/16, W/16, 512]
    • conv2d -> [N, H/16, W/16, 512]
    • BLSTM -> [N, H/16, W/16, 512]
    • FC -> bbox_pred + cls_pred + cls_prob-> [N, H/16, W/16, 410] + [N, H/16, W/16, 2x10] + [N, H/16, W/16, 210]

    4. 损失 model_train.py -> loss()

    • 生成anchor分类标签和bounding-box回归目标 anchor_target_layer()
    • 输入: cls_pred, bbox(GT), im_info(GT)
    • 返回: [rpn_labels, rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights]
    • 方法:
    • 生成基本的anchor(10个),每个anchor对应的四个坐标 [x_min, y_min, x_max, y_max] -> heights = [11, 16, 23, 33, 48, 68, 97, 139, 198, 283] widths = [16]
    • 生成所有的anchor(H/16xW/16x10): 生成feature-map和真实image上anchor之间的偏移量
    • 仅保留那些还在图像内部的anchor,超出图像的都删掉
    • rpn_labels 生成标签(>0.7或者最大的为正标签,<0.3的为负标签),限制标签的数量(总共256个) (先给正的上标签还是先给负的上标签?)
    • rpn_bbox_targets 根据anchor和gtbox计算得真值(anchor和gtbox之间的偏差)
    • 把超出图像范围的anchor再加回来
    • 计算分类损失
    • rpn_cross_entropy_n = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=rpn_label, logits=rpn_cls_score)
      rpn_cross_entropy = tf.reduce_mean(rpn_cross_entropy_n)

    • 计算回归损失
    • rpn_loss_box_n = tf.reduce_sum(rpn_bbox_outside_weights * smooth_l1_dist(rpn_bbox_inside_weights * (rpn_bbox_pred - rpn_bbox_targets)), reduction_indices=[1])
      rpn_loss_box = tf.reduce_sum(rpn_loss_box_n) / (tf.reduce_sum(tf.cast(fg_keep, tf.float32)) + 1)

    • smooth_L1_Loss层理解
      • smooth_L1_Loss是Faster RCNN提出来的计算距离的loss
      • 输入四个bottom,分别是predict,target,inside_weight,outside_weight。与论文并不完全一致,代码中实现的是更加general的版本,公式为:
    • python实现:
      def smooth_l1_loss(bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_weights, sigma=1.0, dim=[1]):
          '''
          bbox_pred   :预测框
          bbox_targets:标签框
          bbox_inside_weights:
          bbox_outside_weights:
          '''  
          sigma_2 = sigma ** 2
          box_diff = bbox_pred - bbox_targets
          in_box_diff = bbox_inside_weights * box_diff
          abs_in_box_diff = tf.abs(in_box_diff)
          # tf.less 返回 True or False; a<b,返回True, 否则返回False。
          smoothL1_sign = tf.stop_gradient(tf.to_float(tf.less(abs_in_box_diff, 1. / sigma_2)))
          # 实现公式中的条件分支
          in_loss_box = tf.pow(in_box_diff, 2) * (sigma_2 / 2.) * smoothL1_sign + (abs_in_box_diff - (0.5 / sigma_2)) * (1. - smoothL1_sign)
          out_loss_box = bbox_outside_weights * in_loss_box
          loss_box = tf.reduce_mean(tf.reduce_sum(out_loss_box, axis=dim))
          return loss_box
    • Smooth L1 Loss相比于L2 Loss对于离群点(outliers)更不敏感(Fast R-CNN中的解释:L1 loss that is less sensitive to outliers than the L2 loss used in R-CNN and SPPnet)。更详细的解释是当预测值与目标值相差很大时,L2 Loss的梯度为(x-t),容易产生梯度爆炸,L1 Loss的梯度为常数,通过使用Smooth L1 Loss,在预测值与目标值相差较大时,由L2 Loss转为L1 Loss可以防止梯度爆炸。 
    • 计算正则损失
    • regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    • 模型损失
    • model_loss = rpn_cross_entropy + rpn_loss_box
    • 总损失
    • total_loss = tf.add_n(regularization_losses) + model_loss

    5. AdamOptimizer() 回归损失

    预测 main/demo.py

    1. 输入

    • input_image [1, H, W, 3]
    • input_im_info [[H, W, C]]

    2. 缩放图片:600x1200

    3. 使用训练好的模型得出 bbox_pred, cls_pred, cls_prob

    4. proposal_layer() 生成propsal

    • 输入:cls_prob, bbox_pred, im_info
    • 返回:textsegs (1 x H x W x A, 5) e.g. [0, x1, y1, x2, y2]
    • 方法:
      生成基本的anchor
      生成整张图像所有的anchor
      根据anchor和bbox_pred,做逆变换,得到box在图像上的真实坐标
      将所有的proposal修建一下,超出图像范围的将会被修剪掉
      移除高度或宽度小于阈值的proposal
      根据分数排序所有的proposal, 进行nms
      输出所有proposal以及分数

    5. TextDetector() 文本检测

      文本线构造算法

    • 输入: textsegs, score[:, np.newaxis], im_info[:2]
    • 输出: 文本行坐标
    • 方法:
    • 删除得分较低的proposal, 阈值0.7
    • 按得分排序
    • 对proposal做nms
    • 文本行的构建(两种方式:水平矩形框和有角度的矩形框)
    • textdetector = TextDetector(DETECT_MODE=‘O’) # DETECT_MODE可以是’O’或者’H’
    • 输出[xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax, score]

    疑问以及代码问题:

    2. rpn_bbox_outside_weights和rpn_bbox_inside_weights用来做什么的
    CTPN中只需要回归proposal的y, h,而bbox_pred的输出为x, y, w, h,所以设置inside_weights=[0, 1, 0, 1]只计算y和h的损失; outside_weights来控制哪些样本参与计算回归损失
    4. 代码中 config.py 中应该为 RPN_BBOX_INSIDE_WEIGHTS = (0.0, 1.0, 0.0, 1.0)

    参考转载:https://blog.csdn.net/m0_38007695/article/details/88699219

  • 相关阅读:
    【常用配置】Spring框架web.xml通用配置
    3.从AbstractQueuedSynchronizer(AQS)说起(2)——共享模式的锁获取与释放
    2.从AbstractQueuedSynchronizer(AQS)说起(1)——独占模式的锁获取与释放
    1.有关线程、并发的基本概念
    0.Java并发包系列开篇
    SpringMVC——DispatcherServlet的IoC容器(Web应用的IoC容器的子容器)创建过程
    关于String的问题
    Spring——Web应用中的IoC容器创建(WebApplicationContext根应用上下文的创建过程)
    <<、>>、>>>移位操作
    System.arraycopy(src, srcPos, dest, destPos, length) 与 Arrays.copyOf(original, newLength)区别
  • 原文地址:https://www.cnblogs.com/lzq116/p/11933625.html
Copyright © 2011-2022 走看看