DAIS: Automatic Channel Pruning via Differentiable Annealing Indicator Search
2020-arxiv-DAIS: Automatic Channel Pruning via Differentiable Annealing Indicator Search
来源:ChenBong 博客园
- Institute:Peking University,Didi Chuxing,Northeastern University
- Author:Yushuo Guan,Kaigui Bian,Zhengping Che,Yanzhi Wang
- GitHub:/
- Citation:/
Introduction
给每个 output channel 附加一个辅助参数 α,用于生成该 channel 的 indicator(其实就是mask,0~1)
设计了 3 个训练损失项,分别用于:
- indicator 约束,使 indicator 稀疏化(二值化)
- FLOPs 约束,约束到目标 FLOPs
- Symmetry 约束,专门用于有残差连接的网络,保持残差块的输出通道剪枝率相同
分为 search stage 和 fine-tune stage, search stage 网络逐步收敛到 compact 结构,随后固定网络结构进行fine-tune
借鉴NAS中可微分的方法(DARTS),使用梯度下降更新 channel 的 indicator
设计 annealing function(温度系数 t),使得随着训练(搜索)的进行,indicator 根据FLOPs要求,逐步收敛到onehot(0 or 1),即可得到 pruned model
Motivation
- 之前的工作 【CNN-FCF (Li et al. 2019)】,也有使用二值化的 channel indicator 来决定是否剪掉某个 channel,但不可微,需要额外的优化工具(如ADMM)
- 之前的自动化通道剪枝【TAS (Dong and Yang 2019)】采用在spernet中搜索 pruned model,但存在候选剪枝子网和 spuernet 之间的 gap 的问题 &&
Contribution
- 在搜索阶段,使用 gradient-based bi-level optimization(和DARTs类似),使用梯度下降交替更新网络权重W 和 辅助参数 α
- 在搜索阶段,设计 annealing function,使得 indicator 逐步收敛到0/1,得到 pruned model
- 在搜索阶段,设计了3种约束
Method
bi-level optimization 优化目标
- W,α 交替更新
- 在训练集上更新W,在验证集上更新 α,避免对训练集的过拟合
- refer:
Annealing Indicator 的设计
第 (l) 层的第 (i) 个 output channel 的 indicator: ( ilde{I}_{l}^{i}, i inleft[1, c_{l} ight]) , ( ilde{I}_{l}^{i}) 的值由 (α^i_l) 决定,且取值范围在 ( ilde{I}_{l}^{i} in [0,1])
简单的归一化设计:
( ilde{I}_{l}^{i}=frac{1}{1+e^{-alpha_{i}^{i}}} qquad (6))
加上退火策略的归一化设计:
( ilde{I}_{l}^{i}=H_{T}left(alpha_{l}^{i} ight)=frac{1}{1+e^{-alpha_{l}^{i} / T}}, quad I_{l}^{i}=lim _{T ightarrow 0} H_{T}left(alpha_{l}^{i} ight) qquad (7))
初始时,温度系数T最大,(T=T_0) ;之后逐渐变小,T采用退火策略 (T=T_0/σ(n)) 逐渐趋于0,n为搜索阶段的epoch数
三种 Regularization
Lasso regularizer
- indicator 的 (l_1 norm)
Continuous FLOPs estimator regularizer
- (FLOP_l=(h×w)×(k^2×c_{in})×c_{out})
Symmetry regularizer
- 对 residual block output channel 的修剪,会导致同一个 stage 的 residual block 的 output channel 不匹配
- 以往的工作:
- 要么不对 residual block 的 output channel 进行修剪,
- 要么修剪后导致同一个 stage 的 residual block 的 output channel 不匹配
- 直接抛弃残差连接
- 使用 1×1 卷积重新进行通道对齐
- (其实还有对 同一个 stage 的 residual block 的 output channel 使用相同的剪枝率的办法,类似于下面的Constrained)
- 这些替代的办法会打断原有的梯度传播,导致梯度爆炸或梯度消失,导致性能下降,作者做了一个实验来证明
- Random 指从 ResNet-110 中随机采样子网,如果 residual block 的input channel 和 output channel 不同,则抛弃该 block 的残差连接
- Constraint 指从 ResNet-110 采样子网,并确保每个 stage 的 residual block 的output channel 相同(可以保留所有残差连接)
- 只对有残差连接的网络使用,确保 residual block 的 (c_{in}=c_{out})
Experiments
Setup
Search Stage
- use (R_{FLOPs}) and (R_{sym}) as the default regularizers
- Search Epochs
- CIFAR:Search 100 epochs,按 7:3 划分 train set, val set
- ImageNet:Search 7508 iterations
- α initialized: (alpha in mathcal{N}left(1,0.1^{2} ight))
- (T_0=1, T_n=T_0/σ(n)) ,(σ(n) = 49×n/N_{max}+1) ,(N_{max}) denotes the total number of training epochs.
- The weights of (R_{FLOPs}) is 2 and (epsilon) = 0.05.
- The weights of (R_{sym}) is 0.01 for ResNet-56/110 and 0 otherwise.
Fine-tuning Stage
-
Train Epochs
- CIFAR:300 epochs
- ImageNet:120 epochs
-
cos lr
CIFAR-10/100
ImageNet
- 最后一列加速比是使用PyTorch Mobile在Salaxy S9手机上得到的
Ablation Study
search methods
- Slimming:使用BN层的值作为 Indicator
- Random:从 ResNet-110 中随机采样子网,如果 residual block 的input channel 和 output channel 不同,则抛弃该 block 的残差连接
- Constraint:从 ResNet-110 采样子网,并确保每个 stage 的 residual block 的output channel 相同(可以保留所有残差连接)
The effectiveness of (R_{FLOPs}) and (R_{sym})
- (w/o R_{FLOPs}) : replaces (R_{FLOPs}) by (R_{lasso})
- (w/o R_{sym}) :removes the symmetry regularizer
- symmetry regularizer 即约束同一个 stage 的 block 的 output channel 相同
The effectiveness of the annealing-relaxed function
- w/o annealing:removing the annealing-relaxed function
- 移除退火策略,即使用简单的归一化 ( ilde{I}_{l}^{i}=frac{1}{1+e^{-alpha_{i}^{i}}} qquad (6)) ,这时候 ( ilde{I}_{l}^{i}) 不会收敛到 [0~1] ,需要引入阈值,将低于阈值的 output channel(filter)剪掉: ( ilde{I}_{l}^{i}<0.55)
The effectiveness of the bi-level optimization
- w/o bi-level:同时在训练集上更新 W 和 α
Robustness of DAIS
- the impact of a shorter training scheme:
- e50:search 50 epochs
- the impact of different temperature decay scheme:
- cosine: (σ(n) = 49×(1−cos(frac{π}{2}n/N_{max}))+1)
- smallT: (σ(n) = 99×n/N_{max}+1)
One-shot capability of DAIS
- 原始方法的FLOPs剪枝率是在 search stage 全程固定的,search 完以后再进行fine-tune
- 这里改为在 search stage 逐步提高剪枝率?
Recoverable Pruning
- 被剪掉的channel可能会重新恢复,自我调整能力
Conclusion
Summary
- 可微分的mask
- 逐渐 onehot 的mask
- 连续的 FLOPs 估计