  • 【Computer Vision】 Reproducing Segmentation Networks (1): SegNet

    Tags: ComputerVision

    Compilation

    1. src/caffe/layers/contrastive_loss_layer.cpp:56:30: error: no matching function for call to ‘max(double, float)’
      Dtype dist = std::max(margin - sqrt(dist_sq_.cpu_data()[i]), Dtype(0.0));

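    In the float instantiation of the layer, sqrt(...) evaluates to a double (hence max(double, float) in the error), while std::max requires both arguments to have the same type. Replace line 56 with:
    Dtype dist = std::max(margin - (float)sqrt(dist_sq_.cpu_data()[i]), Dtype(0.0));
    Casting to Dtype instead, i.e. margin - Dtype(sqrt(dist_sq_.cpu_data()[i])), also compiles and avoids truncating to float in the double instantiation.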
    2. .build_release/lib/libcaffe.so: undefined reference to `cv::imread(cv::String const&, int)'

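    In OpenCV 3, cv::imread moved from highgui into the new imgcodecs module, so the linker cannot find it. Add opencv_imgcodecs to LIBRARIES in the Makefile:
    LIBRARIES += glog gflags protobuf leveldb snappy \
        lmdb boost_system hdf5_hl hdf5 m \
        opencv_core opencv_highgui opencv_imgproc opencv_imgcodecs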

    Data processing

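    1. Computing median frequency balancing
      Image segmentation often runs into class imbalance. If your goal is high accuracy for every class, class balancing during training is important. If your goal is only good overall pixel accuracy, class balancing matters much less: a class that covers few pixels may have low accuracy, but it also has little effect on the overall pixel accuracy.
      Here is how median frequency balancing, as used for SegNet, is computed:
      For a multi-class image dataset, each class has a class frequency: the number of pixels of that class divided by the total number of pixels in the dataset. Take the median of all class frequencies and divide it by each class's own frequency to get that class's weight: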

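    \[ w_c = \frac{\operatorname{median}(f_1, \ldots, f_C)}{f_c} \]
    where f_c is the class frequency of class c and C is the number of classes.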

    This gives classes rarer than the median a weight greater than 1 and more frequent classes a weight less than 1, which achieves the balancing effect.
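    For example, my own data has two classes, 0 and 1, and 55 training images of 500×500 pixels. Counting the class-0 and class-1 pixels over all 55 images:
    count1 = 227611
    count0 = 13522389
    freq1 = 227611 / (500*500*55) = 0.0166
    freq0 = 13522389 / (500*500*55) = 0.9834
    median = 0.5
    weight1 = 0.5 / 0.0166 = 30.12
    weight0 = 0.5 / 0.9834 = 0.508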
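    The calculation above can be scripted directly from the label maps. The sketch below is not from the original post: `class_pixel_counts`, `median_frequency_weights` and `accuracies` are hypothetical helper names, and it assumes integer label maps with values 0 to num_classes-1 in which every class appears at least once (a class with zero pixels would get an infinite weight). Passing the total dataset pixel count as the denominator reproduces the definition above; passing the per-class `present` totals gives a common variant that divides each class's pixels only by the pixels of the images that contain that class. The `accuracies` helper illustrates the earlier point about pixel accuracy: with the counts above, predicting every pixel as class 0 already gives about 98.3% pixel accuracy but only 50% mean class accuracy.

```python
import numpy as np


def class_pixel_counts(label_maps, num_classes):
    """Per-class pixel counts over a dataset of integer label maps (hypothetical helper).

    Returns (counts, present): counts[c] is the number of class-c pixels and
    present[c] is the total pixel count of the images that contain class c.
    Label values >= num_classes (e.g. an ignore label of 255) are not counted
    for any class.
    """
    counts = np.zeros(num_classes, dtype=np.int64)
    present = np.zeros(num_classes, dtype=np.int64)
    for lab in label_maps:
        lab = np.asarray(lab)
        per_image = np.bincount(lab.ravel(), minlength=num_classes)[:num_classes]
        counts += per_image
        present[per_image > 0] += lab.size
    return counts, present


def median_frequency_weights(counts, denominator):
    """weight_c = median(freq) / freq_c with freq_c = counts[c] / denominator[c].

    denominator may be a scalar (all pixels in the dataset) or a per-class array.
    """
    freq = np.asarray(counts, dtype=np.float64) / np.asarray(denominator, dtype=np.float64)
    return np.median(freq) / freq


def accuracies(confusion):
    """Global pixel accuracy and mean per-class accuracy from a confusion matrix
    whose rows are ground-truth classes and whose columns are predictions."""
    confusion = np.asarray(confusion, dtype=np.float64)
    pixel_acc = np.trace(confusion) / confusion.sum()
    class_acc = np.diag(confusion) / confusion.sum(axis=1)
    return pixel_acc, class_acc.mean()


if __name__ == "__main__":
    # Pixel counts from the two-class example above (index 0 = class 0).
    counts = np.array([13522389, 227611])
    total = 500 * 500 * 55
    print("weights:", median_frequency_weights(counts, total))

    # A model that labels every pixel as class 0.
    all_background = [[counts[0], 0], [counts[1], 0]]
    pixel_acc, mean_class_acc = accuracies(all_background)
    print("pixel accuracy:", pixel_acc, "mean class accuracy:", mean_class_acc)
```

    With unrounded frequencies, class 1's weight comes out at about 30.2; the 30.12 above comes from dividing by the rounded 0.0166.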

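    2. Web demo weights
      The class count and labels of the web demo model the author trained do not match those in the model files he released. The web demo weights are fine for running tests, but it is better not to finetune from them; finetune directly from VGG-16 instead.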

    3. Converting RGB labels to grayscale labels

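    Some datasets provide their labels as RGB images, but the label fed into the network during training is normally a label map with values from 0 to class_num-1, so the labels have to be converted. Below is a Python 2 conversion script for the CamVid colour palette: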

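    #!/usr/bin/env python
    from __future__ import print_function  # print(a, b) prints "a b" on Python 2 as well
    import os
    import numpy as np
    try:
        from itertools import izip  # Python 2: lazy zip
    except ImportError:
        izip = zip  # Python 3: the built-in zip is already lazy
    from argparse import ArgumentParser
    from collections import OrderedDict
    from skimage.io import ImageCollection, imsave
    from skimage.transform import resize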
    
    
    camvid_colors = OrderedDict([
        ("Animal", np.array([64, 128, 64], dtype=np.uint8)),
        ("Archway", np.array([192, 0, 128], dtype=np.uint8)),
        ("Bicyclist", np.array([0, 128, 192], dtype=np.uint8)),
        ("Bridge", np.array([0, 128, 64], dtype=np.uint8)),
        ("Building", np.array([128, 0, 0], dtype=np.uint8)),
        ("Car", np.array([64, 0, 128], dtype=np.uint8)),
        ("CartLuggagePram", np.array([64, 0, 192], dtype=np.uint8)),
        ("Child", np.array([192, 128, 64], dtype=np.uint8)),
        ("Column_Pole", np.array([192, 192, 128], dtype=np.uint8)),
        ("Fence", np.array([64, 64, 128], dtype=np.uint8)),
        ("LaneMkgsDriv", np.array([128, 0, 192], dtype=np.uint8)),
        ("LaneMkgsNonDriv", np.array([192, 0, 64], dtype=np.uint8)),
        ("Misc_Text", np.array([128, 128, 64], dtype=np.uint8)),
        ("MotorcycleScooter", np.array([192, 0, 192], dtype=np.uint8)),
        ("OtherMoving", np.array([128, 64, 64], dtype=np.uint8)),
        ("ParkingBlock", np.array([64, 192, 128], dtype=np.uint8)),
        ("Pedestrian", np.array([64, 64, 0], dtype=np.uint8)),
        ("Road", np.array([128, 64, 128], dtype=np.uint8)),
        ("RoadShoulder", np.array([128, 128, 192], dtype=np.uint8)),
        ("Sidewalk", np.array([0, 0, 192], dtype=np.uint8)),
        ("SignSymbol", np.array([192, 128, 128], dtype=np.uint8)),
        ("Sky", np.array([128, 128, 128], dtype=np.uint8)),
        ("SUVPickupTruck", np.array([64, 128, 192], dtype=np.uint8)),
        ("TrafficCone", np.array([0, 0, 64], dtype=np.uint8)),
        ("TrafficLight", np.array([0, 64, 64], dtype=np.uint8)),
        ("Train", np.array([192, 64, 128], dtype=np.uint8)),
        ("Tree", np.array([128, 128, 0], dtype=np.uint8)),
        ("Truck_Bus", np.array([192, 128, 192], dtype=np.uint8)),
        ("Tunnel", np.array([64, 0, 64], dtype=np.uint8)),
        ("VegetationMisc", np.array([192, 192, 0], dtype=np.uint8)),
        ("Wall", np.array([64, 192, 0], dtype=np.uint8)),
        ("Void", np.array([0, 0, 0], dtype=np.uint8))
    ])
    
    
    def convert_label_to_grayscale(im):
        """Map an RGB label image to class indices (each colour's position in camvid_colors)."""
        out = (np.ones(im.shape[:2]) * 255).astype(np.uint8)  # 255 marks "not matched yet"
        for gray_val, (label, rgb) in enumerate(camvid_colors.items()):
            # pixels whose R, G and B all equal this class colour
            match_pxls = np.where((im == np.asarray(rgb)).sum(-1) == 3)
            out[match_pxls] = gray_val
        assert (out != 255).all(), "rounding errors or missing classes in camvid_colors"
        return out.astype(np.uint8)
    
    
    def make_parser():
        parser = ArgumentParser()
        parser.add_argument(
            'label_dir',
            help="Directory containing all RGB camvid label images as PNGs"
        )
        parser.add_argument(
            'out_dir',
            help="""Directory to save grayscale label images.
            Output images have same basename as inputs so be careful not to
            overwrite original RGB labels""")
        return parser
    
    
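    if __name__ == '__main__':
        parser = make_parser()
        args = parser.parse_args()
        labs = ImageCollection(os.path.join(args.label_dir, "*"))
        if not os.path.isdir(args.out_dir):
            os.makedirs(args.out_dir)
        for i, (inpath, im) in enumerate(izip(labs.files, labs)):
            print(i + 1, "of", len(labs))
            # Resize to the caffe-segnet input size (360x480) with nearest-neighbour
            # interpolation. preserve_range keeps the original 0-255 values, so float
            # rounding cannot shift a colour; anti_aliasing=False stops scikit-image
            # (0.14 or later) from blurring neighbouring label colours together.
            resized_im = resize(im, (360, 480), order=0, preserve_range=True,
                                anti_aliasing=False).astype(np.uint8)
            out = convert_label_to_grayscale(resized_im)
            outpath = os.path.join(args.out_dir, os.path.basename(inpath))
            imsave(outpath, out)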
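
    The loop in convert_label_to_grayscale compares the whole image against each of the 32 colours in turn. As an alternative sketch (not the author's script), the same mapping can be done in a single pass by packing every RGB triple into one integer and looking it up with np.searchsorted. rgb_to_class_index and its ignore_value parameter are hypothetical names. Class indices follow the order of the palette you pass in, just as they follow the OrderedDict order above, and the palette is assumed to have fewer than 255 distinct colours so that no index collides with ignore_value. Unlike the script, unmatched colours are marked with ignore_value instead of triggering an assertion.

```python
import numpy as np


def rgb_to_class_index(im, palette, ignore_value=255):
    """Map an H x W x 3 uint8 RGB label image to an H x W uint8 class-index map.

    palette: sequence of (R, G, B) colours; the colour at position i becomes
    class index i. Pixels whose colour is not in the palette get ignore_value.
    """
    im = np.asarray(im, dtype=np.uint8)[..., :3]  # drop an alpha channel if present
    # Pack each pixel's colour into one integer: (R << 16) | (G << 8) | B
    rgb = im.astype(np.int64)
    codes = (rgb[..., 0] << 16) | (rgb[..., 1] << 8) | rgb[..., 2]

    pal = np.asarray(palette, dtype=np.int64)
    pal_codes = (pal[:, 0] << 16) | (pal[:, 1] << 8) | pal[:, 2]

    # Binary-search every pixel code among the sorted palette codes
    order = np.argsort(pal_codes)
    sorted_codes = pal_codes[order]
    pos = np.clip(np.searchsorted(sorted_codes, codes), 0, len(sorted_codes) - 1)
    found = sorted_codes[pos] == codes

    out = np.full(codes.shape, ignore_value, dtype=np.uint8)
    out[found] = order[pos[found]]  # position in the palette = class index
    return out


if __name__ == "__main__":
    palette = [(64, 128, 64), (192, 0, 128), (0, 0, 0)]
    im = np.array([[[64, 128, 64], [0, 0, 0]],
                   [[192, 0, 128], [1, 2, 3]]], dtype=np.uint8)
    print(rgb_to_class_index(im, palette))
```

    In the small example the last pixel, (1, 2, 3), is not in the palette and comes out as 255. With the script's palette the call would be rgb_to_class_index(resized_im, list(camvid_colors.values())).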
    

    Training results

    Test results of a model finetuned from VGG-16 after 20000 iterations:
    [figure]
    label:
    [figure]
    Results of training on my own data, starting from VGG-16:
    [figure]
    label:
    [figure]

    Test results:
    [figure]

  • Original post: https://www.cnblogs.com/vincentcheng/p/9179606.html