zoukankan      html  css  js  c++  java
  • S2DNAS:北大提出动态推理网络搜索,加速推理,可转换任意网络 | ECCV 2020 Oral

    S2DNAS最核心的点在于设计了丰富而简洁的搜索空间,从而能够使用常规的NAS方法即可进行动态推理网络的搜索,解决了动态推理网络的设计问题,可进行任意目标网络的转换

    来源:晓飞的算法工程笔记 公众号

    论文: S2DNAS: Transforming Static CNN Model for Dynamic Inference via Neural Architecture Search

    Introduction


      最近,动态推理作为提升网络推理速度的有效方法,得到了大量关注。相对于剪枝、量化等静态操作,动态推理能够根据样本的难易程度选择合适的计算图,可以很好地平衡准确率和计算消耗,公众号之前也发过一篇相关的Resolution Adaptive Networks for Efficient Inference,有兴趣可以看看。为了实现动态推理,大多数的工作都需要专门的策略来动态地根据输入样本跳过某些计算操作。

      一种经典的方法上在常规卷积网络上添加中间预测层,如图a所示,当中间预测结果的置信度大于阈值,则提前退出。但早期的分类器没有利用深层的语义特征(低分辨率的高维特征),可能会导致明显的准确率下降。

      为了解决上述问题,MSDNet设计了二维(Layer-Scale)多阶段架构来获取各层的粗粒度特征和细粒度特征,如图b所示,每个预测层都能利用深层的语义特征,可达到较好的准确率。然而,MSDNet是精心设计的专用网络结构,若需要转换其它目标网络,则需要重新设计类似的范式。

      为了解决上述问题且不需要重新设计网络结构,论文提议将目标网络转换成channel-wise多阶段网络,如图c所示。该方法保持目标网络的结构,在channel层面将目标网络分成多个阶段,仅在最后的卷积层添加预测器。为了降低计算量,每个阶段的channel数都相对减少。基于图c的思想,论文提出通用结构S2DNAS,能够自动地将目标网络转换成图c架构的动态网络。

    Overview of S2DNAS


      给定目标网络,S2DNAS的转换流程如图2所示,主要包含两个模块:

    • S2D(static-todynamic),生成目标网络特定的搜索空间,由目标网络通过预设的转换方法生成的多阶段网络组成。
    • NAS,在生成的搜索空间中使用强化学习搜索最优的模型,设计了能够反映准确率和资源消耗的回报函数。

    The Details of S2D


      给定目标网络$mathbb{M}$,S2D生成包含由$mathbb{M}$转换的多个网络的搜索空间$mathcal{Z}$,如图3所示,转换过程包含split操作和concat操作:

    • split操作在channel层面上将目标网络分割成多阶段子网,在每个阶段最后添加分类器。
    • concat操作是为了增加阶段间的特征交互,强制当前阶段的分类器复用前面阶段的某些特征。

    Notation

      首先定义一些符号,$X^{(k)}={ x^{(k)}_1, cdots, x^{(k)}_C }$为第$k$层输入,$C$为输入维度,$W^{k}={ w^{(k)}_1, cdots, w^{(k)}_O }$,$O$为输出维度,$w^{(k)}_iin mathbb{R}^{k_c imes k_c imes C}$,转换操作的目标是将目标网络$mathbb{M}$转换成多阶段网络$a={ f_1,cdots, f_s }$,$f_i$为$i$阶段的分类器。

    Split

      Split操作将输入维度的子集赋予不同阶段的分类器,假设阶段数为$s$,直接的方法将输入维度分成$s$个子集,然后将$i$个子集赋予$i$个分类器,但这样会生成较大的搜索空间,阻碍后续的搜索效率。为了降低搜索空间的复杂度,论文先将输入维度分成多组,然后将组分成$s$个子集赋予不同的分类器。
      具体地,将输入维度分成$G$组,每组包含$m=frac{C}{G}$维,以$k$层为例,分组为$X^{(k)}={ x^{(k)}_1, cdots, x^{(k)}G }$,$X^{(k)}i={ x^{(k)}{(i-1)m+1}, cdots, x^{(k)}{im} }$。当分组完成后,使用分割点$(p^{(k)}_0, p^{(k)}1, cdots, p^{(k)}{s-1}, p{(k)}_s)$标记分组的分配,$p{(k)}0=0$和$p^{(k)}s=G$为两个特殊点,将维度分组${ X{(k)}_{p{(k)}{i-1}+1}, cdots, X{(k)}_{p{(k)}{i}}}$分给$i$阶段的分类器$f_i$。

    Concat

      Concat操作用于增加阶段间的特征交互,使得分类器能够复用前面阶段的特征。指示矩阵${ I{(k)}}L_{k=1}$用来表明不同位置的特征是否复用,$k$为层数,$L$为网络的深度,成员$m^{(k)}{ij} in I{(k)}$表明是否在$j$阶段复用$i$阶段的$k$层特征。这里有两个限制,首先只能复用前面阶段的特征$m{(k)}{ij}=0, j<i, forall k < L$,其次$L$层必须复用前面所有阶段的特征。

    Architecture Search Space

      基于上面的两种转换操作,S2D可以生成包含丰富多阶段网络的搜索空间。不同分割点和指示矩阵有不同的意义,调整分割点能够改变分组特征的分配方式,从而改变不同阶段分类器在准确率和资源消耗上的trade-off,而调整指示矩阵则伴随特征复用策略的改变。为了降低搜索空间的大小,在实验时规定目标网络中相同特征大小的层使用相同的分割点和指示矩阵。

    The Details of NAS


      在生成搜索空间后,下一个步骤就是找到最优的动态网络结构,搜索过程将网络$a$表示为两种转换的设置,并标记$mathcal{Z}$为包含不同设置的空间。论文采用NAS常用的policy gradient based强化学习方法进行搜索,该方法的目标是优化策略$pi$,进而通过策略$pi$得到最优的网络结构,优化过程可公式化为嵌套的优化问题:

      $ heta_a$是网络$a$的权值,$pi$是用来生成转换设置的策略,$mathcal{D_{val}}$和$mathcal{D_{train}}$标记验证集和训练集,$R$为验证多阶段网络性能的回报函数。为了解决公式1的优化问题,需要解决两个子问题,根据$ heta^{*}a$优化策略$pi$和优化网络$a$的$ heta{a}$。

    Optimization of the Transformation Settings

      与之前的NAS方法类似(公众号有很多NAS的论文解读,可以去看看),使用RNN生成目标网络每层的不同转换设置的分布,然后policy gradient based算法会优化RNN的参数来最大化回报函数:

      $ACC(a, heta_a, mathcal{D})$为准确率,$COST(a, heta_a, mathcal{D})$为动态推理的平均资源消耗。为了与其它动态推理研究比较,采用FLOPs表示计算消耗,$w$为平衡准确率和资源消耗的超参数。

    Optimization of the Multi-stage CNN

      使用梯度下降来优化内层的优化问题,修改常规的分类损失函数来适应多阶段模型的训练情况:

      $CE$为交叉熵损失函数,公式3可认为是连续训练不同阶段的分类器,可使用SGD及其变种进行参数$ heta$的优化。为了缩短训练时间,仅用几个训练周期来接近$ heta^{*}$,没有完整地训练网络到收敛。训练完成后,在测试集进行回报函数的测试,优化RNN。最后选择10个搜索过程中最优的网络结构进行完整地训练,选择性能最好的网络结构输出。

    Dynamic Inference of the Searched CNN

      对于最优的多阶段网络$a={f_1, cdots, f_s }$后,在使用时为每个阶段预设一个阈值。按计算图依次进行多阶段推理,当该阶段的预测结果达到阈值时,则停止后续的推理,直接使用该阶段结果。

    Experiments


      与多种类型的加速方法对比不同目标网络的转化后性能。

      与MSDNet进行DenseNet转换性能对比。

      不同目标网络转换后各阶段的性能对比。

      准确率与计算量间的trade-off。

      多阶段ResNet-56在CIFAR-10上的模型。

    Conclustion


      S2DNAS最核心的点在于设计了丰富而简洁的搜索空间,从而能够使用常规的NAS方法即可进行动态推理网络的搜索,解决了动态推理网络的设计问题,可进行任意目标网络的转换。不过S2DNAS没有公布搜索时间,而在采用网络训练的时候仅用少量训练周期,没有列出验证训练方式和完整训练得出的准确率是否有偏差。此外,S2DNAS的核心是将静态网络转换成动态网络,如果转换时能够将静态网络的权值用上,可以更有意义,不然直接在目标数据集上搜索就好了,没必要转换。



    如果本文对你有帮助,麻烦点个赞或在看呗~
    更多内容请关注 微信公众号【晓飞的算法工程笔记】

    work-life balance.

  • 相关阅读:
    Jenkins自动化部署入门详细教程
    单元测试
    弱网测试
    Token、Cookie和Session
    测试开发人员必备Linux命令
    TestNG(一)
    char和varchar
    你平时会看日志吗,一般会出现哪些异常(Exception)
    内存溢出和内存泄漏的区别,产生原因以及解决方案
    测试一个电梯
  • 原文地址:https://www.cnblogs.com/VincentLee/p/13518593.html
Copyright © 2011-2022 走看看