zoukankan      html  css  js  c++  java
  • MnasNet:经典轻量级神经网络搜索方法 | CVPR 2019

    论文提出了移动端的神经网络架构搜索方法,该方法主要有两个思路,首先使用多目标优化方法将模型在实际设备上的耗时融入搜索中,然后使用分解的层次搜索空间,来让网络保持层多样性的同时,搜索空间依然很简洁,能够使得搜索的模型在准确率和耗时中有更好的trade off

    来源:【晓飞的算法工程笔记】 公众号

    论文: MnasNet: Platform-Aware Neural Architecture Search for Mobile

    Introduction


      在设计移动端卷积网络时,常常会面临着速度与准确率的取舍问题,为了设计更好的移动端卷积网络,论文提出移动网络的神经网络架构搜索方法,大概步骤如图1所示。对比之前的方法,该方法主贡献有3点:

    • 将设计问题转化为多目标优化问题(multi-objective optimization),同时考虑准确率和实际推理耗时。由于计算量FLOPS其实和实际推理耗时并不总是一致的(MobileNet,575M,113ms vs NASNet,564 M,183ms),所以论文通过实际移动设备上运行来测量推理耗时
    • 之前的搜索方法大都先搜索最优的单元,然后堆叠成网络,虽然这样能优化搜索空间,但抑制了层多样性。为了解决这个问题,论文提出分解的层次搜索空间(factorized hierarchical search space),使得层能存在结构差异的同时,仍然很好地平衡灵活性和搜索空间大小

    • 在符合移动端使用的前提下,达到ImageNet和COCO的SOTA,且速度更快,模型更轻量。如图2所示,在准确率更高的前提下,MansNet速度比MobieNet和NASNet-A分别快1.8倍和2.3倍

    Problem Formulation


      对于模型$m$,$ACC(m)$为模型准确率,$LAT(m)$为目标移动平台的推理耗时,$T$为目标耗时,公式1为在符合耗时前提下,最大化准确率

      但公式1仅最优准确率,没有进行多目标优化(multiple Pareto optimal),于是论文改用公式2的加权乘积方法来近似进行多目标优化

      $w$是权重因子,$alpha$和$eta$为应用特定常数(application-specific constants),这两个值的设定是用来保证符合accuracy-latency trade-offs的有相似的reward,即高准确率稍高耗时和稍低准确率低耗时有相同的reward。例如,凭经验认为两倍耗时通常能带来5%准确率提升,对于模型M1(耗时$l$,准确率$a$),模型M2(耗时$2l$,准确率$a(1+5%)$),他们应该有相同的reward:$Reward(M2)=acdot (1+5%)cdot (2l/T)^etaapprox Reward(M1)=acdot (l/T)^eta$,得到$eta=-0.07$。后面实验没说明都使用$alpha=eta=-0.07$

      图3为不同常数下的目标函数曲线,上图$(alpha=0,eta=-1)$意味着符合耗时的直接输出准确率,超过耗时的则大力惩罚,下图$(alpha=eta=-0.07)$则是将耗时作为软约束,平滑地调整目标函数

    Mobile Neural Architecture Search


    Factorized Hierarchical Search Space

      论文提出分别的层次搜索空间,整体构造如图4所示,将卷积神经网络模型分解成独立的块(block),逐步降低块的输入以及增加块中的卷积核数。每个块进行独立块搜索,每个块包含多个相同的层,由块搜索来决定。搜索的目的是基于输入和输出的大小,选择最合适的算子以及参数(kernal size, filter size)来达到更好的accurate-latency trade-off

      每个块的子搜索包含上面6个步骤,例如图4中的block 4,每层都为inverted bottleneck 5x5 convolution和residual skip path,共$N_4$层

      搜索空间选择使用MobileNetV2作为参考,图4的block数与MobileNetV2对应,MobileNetV2的结构如上。在MobileNetV2的基础上,每个block的layer数量进行${0,+1,-1}$进行加减,而卷积核数则选择${0.75,1.0,1.25}$
      论文提出的分解的层次搜索空间对于平衡层多样性和搜索空间大小有特别的好处,假设共$B$blocks,每个block的子搜索空间大小为$S$,平均每个block有$N$层,总共的搜索空间则为$SB$,对比按层搜索的空间$S{B*N}$小了很多

    Search Algorithm

      论文使用NAS的强化学习方法来优化公式2的rewadr期望,在每一轮,controller根据当前参数$ heta$一批模型,每个模型$m$训练后获得准确率$ACC(m)$以及实际推理耗时$LAT(m)$,根据公式2得到reward,然后使用Proximal Policy Optimization来更新controller的参数$ heta$最大化公式5

    Experimental Setup


      论文先尝试在CIFAR-10上进行架构搜索,然后迁移到大数据集上,但是发现这样不奏效,因为考虑了实际耗时,而应用到大数据集时,网络通常需要放大,耗时就不准确了。因此,论文直接在ImageNet上进行搜索,但每个模型只训练5轮来加速。RNN controller与NASNet保持一致,总共需要64 TPUv2搜索4.5天,每个模型使用Pixel 1手机进行耗时测试,最终大概测试了8K个模型,分别选择了top 15和top 1模型进行完整的ImageNet训练以及COCO迁移,输入图片的分辨率分别为$224 imes 224$和$320 imes 320$

    Results


    ImageNet Classification Performance

      $T=75ms$,$alpha=eta=-0.07$,结果如Table 1所示,MnasNet比MobileNetV2(1.4)快1.8倍,准0.5%,比NASNet-A快2.3倍,准1.2%,而稍大的模型MnasNet-A3比ResNet-50准,但少用了4.8x参数和10x计算量

      由于之前的方法没有使用SE模块,论文补充了个对比训练,MnasNet效果依然比之前的方法要好

    Model Scaling Performance

      缩放模型是调整准确率和耗时的来适应不同设备的常见操作,可以使用depth multiplier(好像叫width multiplier?)来缩放每层的channels数,也可以直接降低输入图片的分辨率。从图5可以看到,MansNet始终保持着比MobileNetV2好的表现

      此外,论文提出的方法能够搜索不同耗时的模型,为了比较性能,论文对比了缩放模型和搜索模型的准确率。从Table4看出,搜索出来的模型有更好的准确率

    COCO Object Detection Performance

      论文对比了MnasNet在COCO上的表现,可以看到MnasNet准确率更高,且少用了7.4x参数和42x计算量

    Ablation Study and Discussion


    Soft vs. Hard Latency Constraint

      多目标搜索方法允许通过设定$alpha$和$eta$进行hard和soft的耗时约束,图6展示了$(alpha=0,eta=-1)$和$(alpha=eta=-0.07)$,目标耗时为75ms,可以看到soft搜索更广的区域,构建了很多接近75ms耗时的模型,也构建了更多小于40ms和大于110ms的模型

    Disentangling Search Space and Reward

      论文将多目标优化和分解的层次搜索空间进行对比实验,从结果来看,多目标优化能很好平衡低耗和准确率,而论文提出的搜索空间能同时降低耗时和提高准确率

    MnasNet Architecture and Layer Diversity

      图7(a)为MnasNet-A1的结构,包含了不同的层结构,可以看到该网络同时使用了5x5和3x3的卷积,之前的方法都只使用了3x3卷积

      Table 6展示了MansNet模型及其变体,变体上仅用某一层的来构建网络,可以看到MnasNet在准确率和耗时上有了更好的trade-off

    CONCLUSION


      论文提出了移动端的神经网络架构搜索方法,该方法使用多目标优化方法将模型在实际设备上的耗时融入搜索中,能够使得搜索的模型在准确率和耗时中有更好的trade off。另外该方法使用分解的层次搜索空间,来让网络保持层多样性的同时,搜索空间依然很简洁,也提高了搜索网络的准确率。从实验结果来看,论文搜索到的网络MansNet在准确率和耗时上都比目前的人工构建网络和自动搜索网络要好



    如果本文对你有帮助,麻烦点个赞或在看呗~
    更多内容请关注 微信公众号【晓飞的算法工程笔记】

    work-life balance.

  • 相关阅读:
    PHP
    PHP
    PHP
    网站页面引导操作
    Solr与Tomcat的整合
    POI操作文档内容
    HashTable和HashMap的区别
    ArrayList、LinkedList、HashMap底层实现
    正则表达式语法
    Java并发编程:线程间通信wait、notify
  • 原文地址:https://www.cnblogs.com/VincentLee/p/13299108.html
Copyright © 2011-2022 走看看