  • SSD: Single Shot Multibox Detection

    Faster R-CNN is actually the most widely used detector these days; I'll get to a project with it later.

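    • The image first passes through the base network block: three halving (downsampling) blocks, each made of two 2D convolution layers (each followed by batch normalization and ReLU) plus a max-pooling layer that halves the height and width. The channel counts increase through 16, 32, 64.
    • Next come three multiscale feature blocks. Each one is a single halving block with the channel count fixed at 128.
    • The last block is a global max-pooling layer that reduces the height and width to 1.
    • Note that every block is followed by two prediction layers: a class prediction layer and a bounding-box prediction layer. The class prediction layer is a 2D convolution with (anchors per pixel) × (classes + 1) output channels, using a 3×3 kernel with padding = 1 so the feature map size is unchanged. The bounding-box prediction layer is the same except that it has (anchors per pixel) × 4 output channels.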
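As a quick check of the channel arithmetic in the last bullet, the sketch below computes the predictor output shapes in plain Python, with no MXNet required. The helper names (`cls_channels`, `bbox_channels`, `pred_shape`, `flattened_length`) are made up for this illustration and are not part of Gluon; the input sizes are the ones used for the `Y1` and `Y2` examples in this post.

```python
def cls_channels(num_anchors, num_classes):
    # One score per anchor per class, plus one extra class for background.
    return num_anchors * (num_classes + 1)

def bbox_channels(num_anchors):
    # Four offsets per anchor.
    return num_anchors * 4

def pred_shape(batch, height, width, channels):
    # A 3x3 convolution with padding=1 and stride 1 keeps height and width.
    return (batch, channels, height, width)

def flattened_length(shape):
    # Values per example after flattening: channels * height * width.
    _, c, h, w = shape
    return c * h * w

y1 = pred_shape(2, 20, 20, cls_channels(5, 10))
y2 = pred_shape(2, 10, 10, cls_channels(3, 10))
print(y1, y2)  # (2, 55, 20, 20) (2, 33, 10, 10)
print(flattened_length(y1) + flattened_length(y2))  # 25300
```

The two scales differ in both channel count and spatial size, so only the batch dimension lines up. That is why the predictions are flattened per example before being concatenated along dimension 1.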
    %matplotlib inline
    import gluonbook as gb
    from mxnet import autograd,gluon,image,init,nd,contrib
    from mxnet.gluon import loss as gloss,nn
    import time
    
    # Class prediction layer
    def cls_predictor(num_anchors,num_classes):
        return nn.Conv2D(num_anchors*(num_classes+1),kernel_size=3,padding=1)
    
    # Bounding-box prediction layer
    def bbox_predictor(num_anchors):
        return nn.Conv2D(num_anchors*4,kernel_size=3,padding=1)
    
    # Concatenating predictions from multiple scales
    def forward(x,block):
        block.initialize()
        return block(x)
    Y1 = forward(nd.zeros((2,8,20,20)),cls_predictor(5,10))
    Y2 = forward(nd.zeros((2,16,10,10)),cls_predictor(3,10))
    
    Y1.shape,Y2.shape
    
    def flatten_pred(pred):
        return pred.transpose((0,2,3,1)).flatten()
    
    def concat_preds(preds):
        return nd.concat(*[flatten_pred(p) for p in preds],dim=1)
    
    concat_preds([Y1,Y2]).shape
    
    # Halving (downsampling) block
    def down_sample_blk(num_channels):
        blk = nn.Sequential()
        for _ in range(2):
            blk.add(nn.Conv2D(num_channels,kernel_size=3,padding=1),
                   nn.BatchNorm(in_channels=num_channels),
                   nn.Activation('relu'))
        blk.add(nn.MaxPool2D(2))
        return blk
    
    blk = down_sample_blk(10)
    blk.initialize()
    x = nd.zeros((2,3,20,20))
    y = blk(x)
    y.shape
    
    # Base network block
    def base_net():
        blk = nn.Sequential()
        for num_filters in [16,32,64]:
            blk.add(down_sample_blk(num_filters))
        return blk
    bnet = base_net()
    bnet.initialize()
    x = nd.random.uniform(shape=(2,3,256,256))
    y = bnet(x)
    y.shape
    
    # The complete model
    def get_blk(i):
        if i==0:              # 0: base network block
            blk = base_net()
        elif i==4:            # 4: global max-pooling block, reduces height and width to 1
            blk = nn.GlobalMaxPool2D()
        else:                 # 1, 2, 3: height/width halving blocks
            blk = down_sample_blk(128)
        return blk
    
    def blk_forward(X,blk,size,ratio,cls_predictor,bbox_predictor):
        Y = blk(X)
        anchors = contrib.nd.MultiBoxPrior(Y,sizes=size,ratios=ratio)
        cls_preds = cls_predictor(Y)
        bbox_preds = bbox_predictor(Y)
        return (Y, anchors, cls_preds,bbox_preds)
    
    sizes = [[0.2, 0.272], [0.37, 0.447], [0.54, 0.619], [0.71, 0.79],
             [0.88, 0.961]]
    ratios = [[1, 2, 0.5]] * 5
    num_anchors = len(sizes[0]) + len(ratios[0]) - 1
    
    
    # The complete TinySSD
    class TinySSD(nn.Block):
        def __init__(self, num_classes, **kwargs):
            super(TinySSD, self).__init__(**kwargs)
            self.num_classes = num_classes
            for i in range(5):
                # Equivalent to the assignment self.blk_i = get_blk(i)
                setattr(self, 'blk_%d' % i,get_blk(i))
                setattr(self, 'cls_%d' % i,cls_predictor(num_anchors,num_classes))
                setattr(self, 'bbox_%d' % i,bbox_predictor(num_anchors))
                
        def forward(self, X):
            anchors, cls_preds, bbox_preds = [None]*5,[None]*5,[None]*5
            for i in range(5):
                # getattr(self, 'blk_%d' % i) accesses self.blk_i
                X, anchors[i], cls_preds[i], bbox_preds[i] = blk_forward(
                    X, getattr(self, 'blk_%d' % i), sizes[i], ratios[i],
                    getattr(self, 'cls_%d' % i), getattr(self, 'bbox_%d' % i))
                
            return (nd.concat(*anchors, dim=1),
                    concat_preds(cls_preds).reshape(
                        (0, -1, self.num_classes + 1)), concat_preds(bbox_preds))
    
    # Check the output shapes
    net = TinySSD(num_classes=1)
    net.initialize()
    X = nd.zeros((32,3,256,256))
    anchors, cls_preds, bbox_preds = net(X)
    print('output anchors:',anchors.shape)
    print('output class preds:',cls_preds.shape)
    print('output bbox preds:',bbox_preds.shape)
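To sanity-check the printed shapes without running MXNet, the following sketch reproduces the bookkeeping by hand. It assumes each max-pooling layer exactly halves the height and width, and that `MultiBoxPrior` generates `len(sizes) + len(ratios) - 1` anchors per feature-map pixel, which is the count the code above uses for `num_anchors`. The function name `tiny_ssd_shapes` is hypothetical.

```python
def tiny_ssd_shapes(batch, image_size, num_classes, anchors_per_pixel=4):
    # Block 0: the base network contains three halving blocks.
    side = image_size // 2 ** 3
    sides = [side]
    # Blocks 1-3: one halving block each.
    for _ in range(3):
        side //= 2
        sides.append(side)
    # Block 4: global max pooling reduces the feature map to 1x1.
    sides.append(1)

    # Every pixel of every feature map contributes anchors_per_pixel anchors.
    total_anchors = sum(s * s for s in sides) * anchors_per_pixel
    return {
        "feature_map_sides": sides,
        "anchors": (1, total_anchors, 4),
        "cls_preds": (batch, total_anchors, num_classes + 1),
        "bbox_preds": (batch, total_anchors * 4),
    }

shapes = tiny_ssd_shapes(batch=32, image_size=256, num_classes=1)
for name, value in shapes.items():
    print(name, value)
```

Under these assumptions a 256×256 input gives feature maps of side 32, 16, 8, 4 and 1, for (1024 + 256 + 64 + 16 + 1) × 4 = 5444 anchors in total. The anchor tensor has a leading dimension of 1 because the anchors depend only on the feature map sizes, not on the individual images in the batch.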
  • Original post: https://www.cnblogs.com/TreeDream/p/10148449.html