zoukankan      html  css  js  c++  java
  • 【DAIS】2020-arxiv-DAIS: Automatic Channel Pruning via Differentiable Annealing Indicator Search-论文阅读

    DAIS: Automatic Channel Pruning via Differentiable Annealing Indicator Search

    2020-arxiv-DAIS: Automatic Channel Pruning via Differentiable Annealing Indicator Search

    来源:ChenBong 博客园

    • Institute:Peking University,Didi Chuxing,Northeastern University
    • Author:Yushuo Guan,Kaigui Bian,Zhengping Che,Yanzhi Wang
    • GitHub:/
    • Citation:/

    Introduction

    image-20201124183013197

    给每个 output channel 附加一个辅助参数 α,用于生成该 channel 的 indicator(其实就是mask,0~1)

    设计了 3 个训练损失项,分别用于:

    1. indicator 约束,使 indicator 稀疏化(二值化)
    2. FLOPs 约束,约束到目标 FLOPs
    3. Symmetry 约束,专门用于有残差连接的网络,保持残差块的输出通道剪枝率相同

    分为 search stage 和 fine-tune stage, search stage 网络逐步收敛到 compact 结构,随后固定网络结构进行fine-tune

    借鉴NAS中可微分的方法(DARTS),使用梯度下降更新 channel 的 indicator

    设计 annealing function(温度系数 t),使得随着训练(搜索)的进行,indicator 根据FLOPs要求,逐步收敛到onehot(0 or 1),即可得到 pruned model

    Motivation

    1. 之前的工作 【CNN-FCF (Li et al. 2019)】,也有使用二值化的 channel indicator 来决定是否剪掉某个 channel,但不可微,需要额外的优化工具(如ADMM)
    2. 之前的自动化通道剪枝【TAS (Dong and Yang 2019)】采用在spernet中搜索 pruned model,但存在候选剪枝子网和 spuernet 之间的 gap 的问题 &&

    Contribution

    1. 在搜索阶段,使用 gradient-based bi-level optimization(和DARTs类似),使用梯度下降交替更新网络权重W 和 辅助参数 α
    2. 在搜索阶段,设计 annealing function,使得 indicator 逐步收敛到0/1,得到 pruned model
    3. 在搜索阶段,设计了3种约束

    Method

    bi-level optimization 优化目标

    image-20201124183128721

    • W,α 交替更新
    • 在训练集上更新W,在验证集上更新 α,避免对训练集的过拟合
      • refer:
      • image-20201124193640335

    Annealing Indicator 的设计

    (l) 层的第 (i) 个 output channel 的 indicator: ( ilde{I}_{l}^{i}, i inleft[1, c_{l} ight])( ilde{I}_{l}^{i}) 的值由 (α^i_l) 决定,且取值范围在 ( ilde{I}_{l}^{i} in [0,1])

    简单的归一化设计:

    ( ilde{I}_{l}^{i}=frac{1}{1+e^{-alpha_{i}^{i}}} qquad (6))

    加上退火策略的归一化设计:

    ( ilde{I}_{l}^{i}=H_{T}left(alpha_{l}^{i} ight)=frac{1}{1+e^{-alpha_{l}^{i} / T}}, quad I_{l}^{i}=lim _{T ightarrow 0} H_{T}left(alpha_{l}^{i} ight) qquad (7))

    初始时,温度系数T最大,(T=T_0) ;之后逐渐变小,T采用退火策略 (T=T_0/σ(n)) 逐渐趋于0,n为搜索阶段的epoch数

    三种 Regularization

    Lasso regularizer

    image-20201124185237215

    • indicator 的 (l_1 norm)

    Continuous FLOPs estimator regularizer

    image-20201124185141872

    • (FLOP_l=(h×w)×(k^2×c_{in})×c_{out})

    Symmetry regularizer

    • 对 residual block output channel 的修剪,会导致同一个 stage 的 residual block 的 output channel 不匹配
    • 以往的工作:
      • 要么不对 residual block 的 output channel 进行修剪,
      • 要么修剪后导致同一个 stage 的 residual block 的 output channel 不匹配
        • 直接抛弃残差连接
        • 使用 1×1 卷积重新进行通道对齐
      • (其实还有对 同一个 stage 的 residual block 的 output channel 使用相同的剪枝率的办法,类似于下面的Constrained)
    • 这些替代的办法会打断原有的梯度传播,导致梯度爆炸或梯度消失,导致性能下降,作者做了一个实验来证明

    image-20201124201418721

    • Random 指从 ResNet-110 中随机采样子网,如果 residual block 的input channel 和 output channel 不同,则抛弃该 block 的残差连接
    • Constraint 指从 ResNet-110 采样子网,并确保每个 stage 的 residual block 的output channel 相同(可以保留所有残差连接)

    image-20201124185209001

    • 只对有残差连接的网络使用,确保 residual block 的 (c_{in}=c_{out})

    Experiments

    Setup

    Search Stage

    • use (R_{FLOPs}) and (R_{sym}) as the default regularizers
    • Search Epochs
      • CIFAR:Search 100 epochs,按 7:3 划分 train set, val set
      • ImageNet:Search 7508 iterations
    • α initialized: (alpha in mathcal{N}left(1,0.1^{2} ight))
    • (T_0=1, T_n=T_0/σ(n))(σ(n) = 49×n/N_{max}+1)(N_{max}) denotes the total number of training epochs.
    • The weights of (R_{FLOPs}) is 2 and (epsilon) = 0.05.
    • The weights of (R_{sym}) is 0.01 for ResNet-56/110 and 0 otherwise.

    Fine-tuning Stage

    • Train Epochs

      • CIFAR:300 epochs
      • ImageNet:120 epochs
    • cos lr

    CIFAR-10/100

    image-20201124185810730

    image-20201124185929546

    ImageNet

    image-20201124185849020

    • 最后一列加速比是使用PyTorch Mobile在Salaxy S9手机上得到的

    Ablation Study

    search methods

    image-20201124191040891

    • Slimming:使用BN层的值作为 Indicator
    • Random:从 ResNet-110 中随机采样子网,如果 residual block 的input channel 和 output channel 不同,则抛弃该 block 的残差连接
    • Constraint:从 ResNet-110 采样子网,并确保每个 stage 的 residual block 的output channel 相同(可以保留所有残差连接)

    The effectiveness of (R_{FLOPs}) and (R_{sym})

    image-20201124191040891

    • (w/o R_{FLOPs}) : replaces (R_{FLOPs}) by (R_{lasso})
    • (w/o R_{sym}) :removes the symmetry regularizer
      • image-20201124192107185
      • symmetry regularizer 即约束同一个 stage 的 block 的 output channel 相同

    The effectiveness of the annealing-relaxed function

    image-20201124191040891

    • w/o annealing​:removing the annealing-relaxed function
      • 移除退火策略,即使用简单的归一化 ( ilde{I}_{l}^{i}=frac{1}{1+e^{-alpha_{i}^{i}}} qquad (6)) ,这时候 ( ilde{I}_{l}^{i}) 不会收敛到 [0~1] ,需要引入阈值,将低于阈值的 output channel(filter)剪掉: ( ilde{I}_{l}^{i}<0.55)

    The effectiveness of the bi-level optimization

    image-20201124191040891

    • w/o bi-level:同时在训练集上更新 W 和 α

    Robustness of DAIS

    image-20201124191054009

    • the impact of a shorter training scheme:
      • e50:search 50 epochs
    • the impact of different temperature decay scheme:
      • cosine(σ(n) = 49×(1−cos(frac{π}{2}n/N_{max}))+1)
      • smallT(σ(n) = 99×n/N_{max}+1)

    One-shot capability of DAIS

    image-20201124191109538

    • 原始方法的FLOPs剪枝率是在 search stage 全程固定的,search 完以后再进行fine-tune
    • 这里改为在 search stage 逐步提高剪枝率?

    Recoverable Pruning

    image-20201124191123859

    • 被剪掉的channel可能会重新恢复,自我调整能力

    Conclusion

    Summary

    • 可微分的mask
    • 逐渐 onehot 的mask
    • 连续的 FLOPs 估计

    To Read

    Reference

  • 相关阅读:
    C++编程开发学习的50条建议(转)
    编程思想:我现在是这样编程的(转)
    Linux系统编程@多线程与多进程GDB调试
    字符串分割函数 STRTOK & STRTOK_R (转)
    C语言指针与数组的定义与声明易错分析
    C语言 a和&a的区别
    C语言二重指针与malloc
    【C语言入门】C语言的组成结构(基础完整篇)!
    程序员吐槽女友败家:开酒店必须400元起步,工资却不到自己的一半!
    怎样才能和编程语言对上眼?你需要做些准备以及...
  • 原文地址:https://www.cnblogs.com/chenbong/p/14047882.html
Copyright © 2011-2022 走看看