  • Gate Decorator: Global Filter Pruning Method for Accelerating Deep Convolutional Neural Networks

    https://github.com/youzhonghui/gate-decorator-pruning

    1.utils.py

    class dotdict(dict):
        """dot.notation access to dictionary attributes"""
        __getattr__ = dict.get
        __setattr__ = dict.__setitem__
        __delattr__ = dict.__delitem__

    This class inherits from dict, so it is still essentially a dict; it just adds dot-notation access to the keys.
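
    A minimal sketch of what that buys (hypothetical keys; only the dotdict class above is assumed):

    from utils import dotdict

    d = dotdict({'batch_size': 128})
    print(d.batch_size)      # 128 -- attribute access goes through dict.get
    d.num_workers = 4        # same as d['num_workers'] = 4
    print(d['num_workers'])  # 4
    print(d.missing_key)     # None (dict.get returns None instead of raising AttributeError)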

    2.loader/__init__.py

    import torchvision
    from torch.utils.data import DataLoader
    from torchvision import transforms
    
    from PIL import Image
    from config import cfg
    
    from loader.cifar10 import get_cifar10
    from loader.cifar100 import get_cifar100
    from loader.imagenet import get_imagenet
    
    def get_loader():
        pair = { # choose which dataset to use based on the dataset name in the config
            'cifar10': get_cifar10,
            'cifar100': get_cifar100,
            'imagenet': get_imagenet
        }
    
        return pair[cfg.data.type]()

    This selects which dataset to use; the corresponding config settings are:

    from config import parse_from_dict
    parse_from_dict({
    ...
        "data": {
            "type": "cifar10", #这个即使用的数据集名字
            "shuffle": True,
            "batch_size": 128,
            "test_batch_size": 128,
            "num_workers": 4
    ...

    3.models/__init__.py

    import torch
    from config import cfg
    
    def get_vgg16_for_cifar():
        from models.cifar.vgg import VGG
        return VGG('VGG16', cfg.model.num_class)
    
    def get_resnet50_for_imagenet():
        from models.imagenet.resnet50 import Resnet50
        return Resnet50(cfg.model.num_class)
    
    def get_resnet56():
        from models.cifar.resnet56 import resnet56
        return resnet56(cfg.model.num_class)
    
    def get_model():
        pair = { # choose which model to use based on the model name in the config
            'cifar.vgg16': get_vgg16_for_cifar,
            'resnet50': get_resnet50_for_imagenet,
            'cifar.resnet56': get_resnet56
        }
    
        model = pair[cfg.model.name]()
    
        if cfg.base.checkpoint_path != '': # load a trained checkpoint if one is given
            print('restore checkpoint: ' + cfg.base.checkpoint_path)
            model.load_state_dict(torch.load(cfg.base.checkpoint_path, map_location='cpu' if not cfg.base.cuda else 'cuda'))
    
        if cfg.base.cuda: # single GPU
            model = model.cuda()
    
        if cfg.base.multi_gpus: # multiple GPUs
            model = torch.nn.DataParallel(model)
        return model

    This selects which model to use for classification, chooses CPU or GPU, and loads a pretrained checkpoint if one exists.

    The corresponding config settings are:

    from config import parse_from_dict
    parse_from_dict({
    ...
        "model": {
            "name": "cifar.resnet56",
            "num_class": 10,
            "pretrained": False
        },

    4.loss.py

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.autograd import Variable
    import numpy as np
    from config import cfg
    
    def get_criterion():
        pair = { # choose which loss function to use
            'softmax': nn.CrossEntropyLoss()
        }
    
        assert (cfg.loss.criterion in pair)
        criterion = pair[cfg.loss.criterion]
        return criterion

    The cross-entropy loss is used.

    The corresponding config settings are:

    from config import parse_from_dict
    parse_from_dict({
    ...
        "loss": {
            "criterion": "softmax"

    5. config.py

    import argparse
    import json
    from utils import dotdict
    
    def make_as_dotdict(obj): # convert a nested dict into dotdict format
        if type(obj) is dict:
            obj = dotdict(obj)
            for key in obj:
                if type(obj[key]) is dict:
                    obj[key] = make_as_dotdict(obj[key])
        return obj
    
    def parse():
        print('Parsing config file...')
        parser = argparse.ArgumentParser(description="config")
        parser.add_argument(
            "--config",
            type=str,
            default="configs/base.json",
            help="Configuration file to use"
        )
        cli_args = parser.parse_args()
    
        with open(cli_args.config) as fp:
            config = make_as_dotdict(json.loads(fp.read()))
        print(json.dumps(config, indent=4, sort_keys=True))
        return config
    
    class Singleton(object):
        _instance = None
        def __new__(cls, *args, **kw):
            if not cls._instance:
                cls._instance = super(Singleton, cls).__new__(cls, *args, **kw)  
            return cls._instance 
    
    class Config(Singleton):
        def __init__(self):
            self._cfg = dotdict({})
            try:
                self._cfg = parse()
            except:
                pass
    
        def __getattr__(self, name):
            if name == '_cfg':
                super().__setattr__(name)
            else:
                return self._cfg.__getattr__(name)
    
        def __setattr__(self, name, val):
            if name == '_cfg':
                super().__setattr__(name, val)
            else:
                self._cfg.__setattr__(name, val)
    
        def __delattr__(self, name): # called when an attribute is deleted with del
            return self._cfg.__delitem__(name)
    
        def copy(self, new_config):
            self._cfg = make_as_dotdict(new_config)
    
    cfg = Config()
    
    def parse_from_dict(d): # replace the global cfg contents with the given dict (converted to dotdict)
        global cfg
        assert type(d) == dict
        cfg.copy(d)

    This module is where the parameters are set.

    At first it may not be obvious why the config is wrapped in dotdict format; the main benefit is attribute-style access to nested keys, as sketched below.
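
    A small illustration (hypothetical keys; only parse_from_dict and cfg from config.py are assumed):

    from config import parse_from_dict, cfg

    parse_from_dict({'data': {'type': 'cifar10', 'batch_size': 128}})
    # nested dicts become dotdicts, so every level supports attribute access
    print(cfg.data.type)        # 'cifar10'  (instead of config_dict['data']['type'])
    print(cfg.data.batch_size)  # 128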

    parse_from_dict() is called later during pruning and finetuning to set the parameter information, e.g.:

    from config import parse_from_dict
    parse_from_dict({
        "base": {
            "task_name": "resnet56_cifar10_ticktock",
            "cuda": True,
            "seed": 0,
            "checkpoint_path": "",
            "epoch": 0,
            "multi_gpus": True,
            "fp16": False
        },
        "model": {
            "name": "cifar.resnet56",
            "num_class": 10,
            "pretrained": False
        },
        "train": {
            "trainer": "normal",
            "max_epoch": 160,
            "optim": "sgd",
            "steplr": [
                [80, 0.1],   # lr = 0.1 for epoch <= 80
                [120, 0.01], # lr = 0.01 for 80 < epoch <= 120
                [160, 0.001] # lr = 0.001 for 120 < epoch <= 160
            ],
            "weight_decay": 5e-4,
            "momentum": 0.9,
            "nesterov": False
        },
        "data": {
            "type": "cifar10",
            "shuffle": True,
            "batch_size": 128,
            "test_batch_size": 128,
            "num_workers": 4
        },
        "loss": {
            "criterion": "softmax"
        },
        "gbn": {
            "sparse_lambda": 1e-3,
            "flops_eta": 0,
            "lr_min": 1e-3,
            "lr_max": 1e-2,
            "tock_epoch": 10,
            "T": 10,
            "p": 0.002
        }
    })
    from config import cfg

    6. trainer/__init__.py

    from trainer.normal import NormalTrainer
    from config import cfg
    
    def get_trainer():
        pair = {
            'normal': NormalTrainer
        }
        assert (cfg.train.trainer in pair)
    
        return pair[cfg.train.trainer]()

    This selects the trainer, i.e. where the train() and test() functions used for training and testing are defined.

    Extension:

    #coding:utf-8
    import torch
    if __name__ == '__main__':
        a = torch.FloatTensor([[3, 14, 15, 13], [5,4,15,7]]).t()
        b = torch.FloatTensor([[3, 3, 3, 3], [5,5,5,5]]).t()
        correct = a.eq(b)
        print(correct)
    
        print(correct[:1])
        print(correct[:1].view(-1))
        print(correct[:1].view(-1).float())
        correct_k = correct[:1].view(-1).float().sum(0, keepdim=True)
        print(correct_k)

    Output:

    tensor([[ True,  True],
            [False, False],
            [False, False],
            [False, False]])
    tensor([[True, True]])
    tensor([True, True])
    tensor([1., 1.])
    tensor([2.])

    trainer/normal.py:

    from time import time
    
    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    import torch.nn.functional as F
    
    from tqdm import tqdm
    import numpy as np
    from config import cfg
    
    FINISH_SIGNAL = 'finish'
    
    def accuracy(output, target, topk=(1,)): # compute classification accuracy for the top-k predictions
        """Computes the accuracy over the k top predictions for the specified values of k"""
        with torch.no_grad():
            maxk = max(topk) # e.g. top-1 or top-5
            batch_size = target.size(0)
    
            _, pred = output.topk(maxk, 1, True, True) # indices of the maxk largest predictions, shape (batch_size, maxk)
            pred = pred.t() # transpose to (maxk, batch_size)
            # target goes from (batch_size,) -> (1, batch_size) -> expanded to (maxk, batch_size)
            # and is compared elementwise with pred; each sample matches at most once,
            # so correct (shape (maxk, batch_size)) contains at most batch_size True values
            correct = pred.eq(target.view(1, -1).expand_as(pred))
    
            res = []
            for k in topk:
                correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
                res.append(correct_k.mul_(100.0 / batch_size)) # accuracy as a percentage
            return res
    
    class NormalTrainer():
        def __init__(self):
            self.use_cuda = cfg.base.cuda
    
        def test(self, pack, topk=(1,)): # evaluation
            pack.net.eval()
            loss_acc, correct, total = 0.0, 0.0, 0.0
            hub = [[] for i in range(len(topk))]
    
            for data, target in pack.test_loader:
                if self.use_cuda:
                    data, target = data.cuda(), target.cuda()
    
                with torch.no_grad(): # no backward pass / no gradients
                    output = pack.net(data)
                    loss_acc += pack.criterion(output, target).data.item() # accumulate the loss
                    acc = accuracy(output, target, topk) # accuracy
                    for acc_idx, score in enumerate(acc): 
                        hub[acc_idx].append(score[0].item())
    
            loss_acc /= len(pack.test_loader) # average loss over the test set
            info = {
                'test_loss': loss_acc
            }
            
            for acc_idx, k in enumerate(topk):
                info['acc@%d' % k] = np.mean(hub[acc_idx]) # top-1, top-5, etc. accuracy
    
            return info
    
        def train(self, pack, loss_hook=None, iter_hook=None, update=True, mute=False, acc_step=1): # one training epoch; mute hides the tqdm progress bar, acc_step enables gradient accumulation
            pack.net.train()
            loss_acc, correct_acc, total = 0.0, 0.0, 0.0
            begin = time()
    
            pack.optimizer.zero_grad()
            with tqdm(total=len(pack.train_loader), disable=mute) as pbar:
                total_iter = len(pack.train_loader) # total number of iterations (batches) per epoch
                for cur_iter, (data, label) in enumerate(pack.train_loader):
                    if iter_hook is not None:
                        signal = iter_hook(cur_iter, total_iter)
                        if signal == FINISH_SIGNAL: # finish signal: stop this epoch early
                            break
                    if self.use_cuda:
                        data, label = data.cuda(), label.cuda()
                    data = Variable(data, requires_grad=False)
                    label = Variable(label)
    
                    logits = pack.net(data)
                    loss = pack.criterion(logits, label)
                    if loss_hook is not None:
                        additional = loss_hook(data, label, logits)
                        loss += additional
                    loss = loss / acc_step
                    loss.backward()
    
                    if (cur_iter + 1) % acc_step == 0:
                        if update:
                            pack.optimizer.step()
                        pack.optimizer.zero_grad()
    
                    loss_acc += loss.item()
                    pbar.update(1)
    
            info = {
                'train_loss': loss_acc / len(pack.train_loader),
                'epoch_time': time() - begin
            }
            return info

    A single call to train() runs through all the data exactly once, i.e. it ends when enumerate(pack.train_loader) is exhausted, so one call corresponds to one epoch.

    7.main.py

    """
     * Copyright (C) 2019 Zhonghui You
     * If you are using this code in your research, please cite the paper:
     * Gate Decorator: Global Filter Pruning Method for Accelerating Deep Convolutional Neural Networks, in NeurIPS 2019.
    """
    
    import torch
    import torch.nn as nn
    import torch.optim as optim
    
    import numpy as np
    import random
    import math
    
    from loader import get_loader
    from models import get_model
    from trainer import get_trainer
    from loss import get_criterion
    
    from utils import dotdict
    from config import cfg
    from logger import logger
    
    
    def _sgdr(epoch):
        lr_min, lr_max = cfg.train.sgdr.lr_min, cfg.train.sgdr.lr_max
        restart_period = cfg.train.sgdr.restart_period
        _epoch = epoch - cfg.train.sgdr.warm_up
    
        while _epoch/restart_period > 1.:
            _epoch = _epoch - restart_period
            restart_period = restart_period * 2.
    
        radians = math.pi*(_epoch/restart_period)
        return lr_min + (lr_max - lr_min) *  0.5*(1.0 + math.cos(radians))
    
    def _step_lr(epoch):
        v = 0.0
        for max_e, lr_v in cfg.train.steplr: # lr_v applies to all epochs up to max_e
            v = lr_v
            if epoch <= max_e:
                break
        return v
    
    def get_lr_func():
        if cfg.train.steplr is not None:
            return _step_lr
        elif cfg.train.sgdr is not None:
            return _sgdr
        else:
            assert False
    
    def adjust_learning_rate(epoch, pack): # create the optimizer and LR scheduler on first call, then update the learning rate
        if pack.optimizer is None:
            if cfg.train.optim == 'sgd' or cfg.train.optim is None:
                pack.optimizer = optim.SGD(
                    pack.net.parameters(),
                    lr=1,
                    momentum=cfg.train.momentum,
                    weight_decay=cfg.train.weight_decay,
                    nesterov=cfg.train.nesterov
                )
            else:
                print('WRONG OPTIM SETTING!')
                assert False
            pack.lr_scheduler = optim.lr_scheduler.LambdaLR(pack.optimizer, get_lr_func())
    
        pack.lr_scheduler.step(epoch)
        return pack.lr_scheduler.get_lr()
    
    def recover_pack():
        train_loader, test_loader = get_loader()
    
        pack = dotdict({
            'net': get_model(),
            'train_loader': train_loader,
            'test_loader': test_loader,
            'trainer': get_trainer(),
            'criterion': get_criterion(),
            'optimizer': None,
            'lr_scheduler': None
        })
    
        adjust_learning_rate(cfg.base.epoch, pack)
        return pack
    
    def set_seeds(): # make the random numbers in the code reproducible across runs
        torch.manual_seed(cfg.base.seed)
        if cfg.base.cuda:
            torch.cuda.manual_seed_all(cfg.base.seed)
            torch.backends.cudnn.deterministic = True
            if cfg.base.fp16:
                torch.backends.cudnn.enabled = True
                # torch.backends.cudnn.benchmark = True
        np.random.seed(cfg.base.seed)
        random.seed(cfg.base.seed)
    
    
    def main():
        set_seeds() # this is where the "seed": 0 from the config is used
        pack = recover_pack() # build the model, data loaders, trainer, criterion, etc.
    
        for epoch in range(cfg.base.epoch + 1, cfg.train.max_epoch + 1):
            lr = adjust_learning_rate(epoch, pack) # update the learning rate
            info = pack.trainer.train(pack) # train the model; returns the training loss etc.
            info.update(pack.trainer.test(pack)) # add the test loss and accuracy
            info.update({'LR': lr}) # record the current lr
            print(epoch, info)
            logger.save_record(epoch, info) # write to the log
            if epoch % cfg.base.model_saving_interval == 0:
                logger.save_network(epoch, pack.net) # save the network
    
    if __name__ == '__main__':
        main()
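
    For the sample steplr setting above ([[80, 0.1], [120, 0.01], [160, 0.001]]), a minimal check of what _step_lr returns per epoch; since the optimizer is built with lr=1, LambdaLR makes this value the effective learning rate (the helper below just copies the _step_lr logic for illustration):

    def _step_lr_demo(epoch, steplr=((80, 0.1), (120, 0.01), (160, 0.001))):
        # same logic as _step_lr: use the lr of the first boundary the epoch has not passed yet
        v = 0.0
        for max_e, lr_v in steplr:
            v = lr_v
            if epoch <= max_e:
                break
        return v

    for e in (1, 80, 81, 120, 121, 160):
        print(e, _step_lr_demo(e))
    # 1 0.1 | 80 0.1 | 81 0.01 | 120 0.01 | 121 0.001 | 160 0.001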

    8.logger.py

    import torch
    
    from config import cfg
    import os
    import json
    import numpy as np
    
    
    class MetricsRecorder():
        def __init__(self):
            self.rec = {}
    
        def add(self, pairs):
            for key, val in pairs.items():
                if key not in self.rec:
                    self.rec[key] = []
                self.rec[key].append(val)
    
        def mean(self):
            r = {}
            for key, val in self.rec.items():
                r[key] = np.mean(val)
            return r
    
    class Logger():
        def __init__(self):
            self.base_path = './logs/' + cfg.base.task_name
            self.logfile = self.base_path + '/log.json'
            self.cfgfile = self.base_path + '/cfg.json'
    
            if not os.path.isdir(self.base_path):
                os.makedirs(self.base_path, exist_ok=True)
                with open(self.logfile, 'w') as fp:
                    json.dump({}, fp) # the log is empty at initialization
                with open(self.cfgfile, 'w') as fp:
                    json.dump(cfg, fp) # the config is dumped at initialization
    
        def save_record(self, epoch, record): # save the train/test losses, accuracies, etc. produced during the run, keyed by the current epoch
            with open(self.logfile) as fp:
                log = json.load(fp)
    
            log[str(epoch)] = record
            with open(self.logfile, 'w') as fp:
                json.dump(log, fp)
    
        def save_network(self, epoch, network):
            saving_path = self.base_path + '/ckp.%d.torch' % epoch
            print('saving model ...')
            if type(network) is torch.nn.DataParallel:
                torch.save(network.module.state_dict(), saving_path)
            else:
                torch.save(network.state_dict(), saving_path)
    
            cfg.base.epoch = epoch
            cfg.base.checkpoint_path = saving_path
            with open(self.cfgfile, 'w') as fp: # save the updated config
                json.dump(cfg, fp)
    
    logger = None
    if logger is None:
        logger = Logger()

    The corresponding files are created according to the task name, i.e. the setting:

    from config import parse_from_dict
    parse_from_dict({
        "base": {
            "task_name": "resnet18", #任务名字

    A folder named after the task is created under ./logs to store the log.json and cfg.json files; save_record() writes the intermediate information there, and calling save_network() also saves the model checkpoints into that folder.

    user@jiayuan:/opt/.../gate-decorator-pruning/logs/resnet18$ ls
    cfg.json  log.json
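
    A small sketch of how the resulting log.json could be inspected afterwards ('resnet18' is just the example task name above):

    import json

    with open('./logs/resnet18/log.json') as fp:
        log = json.load(fp)

    # one record per epoch, e.g. {'train_loss': ..., 'test_loss': ..., 'acc@1': ..., 'LR': ...}
    for epoch in sorted(log, key=int):
        print(epoch, log[epoch].get('acc@1'))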

    Next come pruning and finetuning, which are the important parts.

    9.prune/utils.py

    #coding:utf-8
    import os
    
    if __name__ == '__main__':
        print(os.devnull) #/dev/null

    The code:

    import torch
    import torch.nn as nn
    
    import os, contextlib
    from thop import profile
    
    def analyse_model(net, inputs):
        # silence
        with open(os.devnull, 'w') as devnull: # os.devnull is /dev/null on Linux
            with contextlib.redirect_stdout(devnull): # stdout is redirected to /dev/null
                flops, params = profile(net, (inputs, )) # thop.profile estimates the FLOPs and params of a PyTorch model
        return flops, params
    
    
    def finetune(pack, lr_min, lr_max, T, mute=False): # T is the number of finetune epochs, e.g. 40
        logs = []
        epoch = 0
    
        def iter_hook(curr_iter, total_iter): # passed to train() as its iter_hook argument
            total = T * total_iter # total_iter is the number of batches per epoch, so the whole finetune runs `total` batches
            half = total / 2
            itered = epoch * total_iter + curr_iter # curr_iter is the batch index within the current epoch; itered counts all batches seen so far
            if itered < half: # in the first half, the lr ramps linearly from lr_min up to lr_max
                _iter = epoch * total_iter + curr_iter
                _lr = (1- _iter / half) * lr_min + (_iter / half) * lr_max
            else: # in the second half it ramps back from lr_max down to lr_min (same formula with lr_min and lr_max swapped)
                _iter = (epoch - T/2) * total_iter + curr_iter
                _lr = (1- _iter / half) * lr_max + (_iter / half) * lr_min
    
            for g in pack.optimizer.param_groups:
                g['lr'] = max(_lr, 0)
                g['momentum'] = 0.0
    
        for i in range(T): # train for T epochs
            info = pack.trainer.train(pack, iter_hook = iter_hook)
            info.update(pack.trainer.test(pack))
            info.update({'LR': pack.optimizer.param_groups[0]['lr']})
            epoch += 1
            if not mute: # whether to print the loss and accuracy info
                print(info)
            logs.append(info)
    
        return logs

     The finetune operation here corresponds to the finetune step in the paper:

     It differs from the tock phase of Tick-Tock in that tock still uses GBN and trains for only a few epochs (around 10), whereas finetune is run on the small model after the whole pruning is finished, with the GBN layers converted back to BN, and it trains for more epochs.
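
    A minimal numeric sketch of the resulting schedule, replicating the iter_hook arithmetic above with hypothetical values (T=4 epochs, 100 batches per epoch, lr_min=1e-3, lr_max=1e-2):

    T, total_iter = 4, 100
    lr_min, lr_max = 1e-3, 1e-2

    def lr_at(epoch, curr_iter):
        # same arithmetic as iter_hook: linear ramp up in the first half, ramp down in the second
        half = T * total_iter / 2
        itered = epoch * total_iter + curr_iter
        if itered < half:
            return (1 - itered / half) * lr_min + (itered / half) * lr_max
        _iter = (epoch - T / 2) * total_iter + curr_iter
        return (1 - _iter / half) * lr_max + (_iter / half) * lr_min

    print(lr_at(0, 0), lr_at(1, 99), lr_at(2, 0), lr_at(3, 99))
    # 0.001  ~0.00996  0.01  ~0.001 -- ramps from lr_min up to lr_max at the midpoint, then back down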

    10.prune/universal.py

    Extensions

    1) The uuid library:

    UUID (Universally Unique Identifier): a UUID is guaranteed to be unique in both space and time. Uniqueness is guaranteed through the MAC address, timestamp, namespace, and random or pseudo-random numbers, and it has a fixed size of 128 bits. Its uniqueness and consistency mean a new UUID can be generated without any registration process. UUIDs can be used for many purposes: to tag an object for a short time, or to reliably identify a persistent object across a network.

      Why use a UUID?

      Many scenarios need an id that carries no specific meaning and is only used to identify an object. Common examples are the id column of a database table, or front-end UI libraries that dynamically create UI elements which each need a unique id.

    #coding:utf-8
    import uuid
    
    if __name__ == '__main__':
        print(uuid.uuid1()) #7b24099a-27ae-11ea-b076-00e04c6841ff

     This library is mainly used for networks like ResNet that contain shortcut branches: GBN layers that must be pruned together are put in the same group and share the same ID, whereas in a network like VGG every GBN layer has its own ID.
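
    A hedged sketch of how such grouping could look for one residual connection (set_groupid comes from the GatedBatchNorm2d class further below; the gbn_layers lookup table and the layer names are hypothetical):

    import uuid

    # GBN layers whose outputs are added together by a shortcut must keep the same
    # channels, so they are assigned one shared group id
    shared_id = uuid.uuid1()
    gbn_after_conv2 = gbn_layers['layer1.0.bn2']        # hypothetical lookup table of GBN modules
    gbn_on_shortcut = gbn_layers['layer1.0.shortcut.1']
    gbn_after_conv2.set_groupid(shared_id)
    gbn_on_shortcut.set_groupid(shared_id)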

    2)nn.Parameter

    #coding:utf-8
    import torch.nn as nn
    import torch
    
    if __name__ == '__main__':
        g = nn.Parameter(torch.ones(1, 3, 1, 1), requires_grad=True)
        print(g)

    Output:

    Parameter containing:
    tensor([[[[1.]],
    
             [[1.]],
    
             [[1.]]]], requires_grad=True)

    The purpose of nn.Parameter is to convert a non-trainable Tensor into a trainable parameter and bind it to the module (the bound parameter appears in net.parameters(), so it can be updated during optimization). After this conversion the value becomes part of the model, a parameter that training can change.


    In a model, the bias and weight are nn.Parameters, which are trained and optimized; Variables serve as the model's inputs.

    buffers() returns an iterator over a module's buffers, which hold persistent state that each forward pass needs from previous forward passes, such as the running mean and variance used by BatchNorm2d(); these are updated as BatchNorm2d() processes data, but they are not optimized as parameters.
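
    A small sketch contrasting the two (standard PyTorch only):

    import torch.nn as nn

    bn = nn.BatchNorm2d(3)
    print([name for name, _ in bn.named_parameters()])
    # ['weight', 'bias']  -> trainable, updated by the optimizer
    print([name for name, _ in bn.named_buffers()])
    # ['running_mean', 'running_var', 'num_batches_tracked']  -> persistent state, not optimized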

    3)

     So the initialization of GatedBatchNorm2d sets up its parameters as follows:

        def extract_from_bn(self):
            # freeze bn weight
            with torch.no_grad():
                self.bn.bias.set_(torch.clamp(self.bn.bias / self.bn.weight, -10, 10)) # clamp self.bn.bias / self.bn.weight to [-10, 10]: values below -10 become -10, values above 10 become 10
                self.g.set_(self.g * self.bn.weight.view(1, -1, 1, 1))
                self.bn.weight.set_(torch.ones_like(self.bn.weight)) # torch.ones_like(input) is equivalent to torch.ones(input.size())
                self.bn.weight.requires_grad = False

    Here Φ is g, β is bn.bias, and γ is self.bn.weight.
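
    A quick standalone check that this reparameterization leaves the layer output unchanged (plain tensors standing in for γ, β and the normalized activation x_hat; g starts at 1 as in GatedBatchNorm2d):

    import torch

    gamma = torch.tensor([0.5, 1.0, 2.0, 1.5])
    beta = torch.tensor([0.1, -0.3, 0.7, 0.0])
    x_hat = torch.randn(4)

    out_bn = gamma * x_hat + beta                   # original BN output: gamma * x_hat + beta

    g = torch.ones(4) * gamma                       # g absorbs gamma
    new_bias = torch.clamp(beta / gamma, -10, 10)   # new bias: beta / gamma
    out_gbn = (1.0 * x_hat + new_bias) * g          # GBN output with bn.weight frozen to 1

    print(torch.allclose(out_bn, out_gbn))          # True (as long as beta / gamma stays inside [-10, 10])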

    After pruning, once the filters to remove are known, the code that converts GBN back to a plain BN is:

        def melt(self):
            with torch.no_grad():
                mask = self.bn_mask.view(-1) # flatten to a vector with one value per channel; 0 means that channel was pruned
                replacer = nn.BatchNorm2d(int(self.bn_mask.sum())).to(self.bn.weight.device)
                replacer.running_var.set_(self.bn.running_var[mask != 0]) # running variance from BatchNorm2d
                replacer.running_mean.set_(self.bn.running_mean[mask != 0]) # running mean from BatchNorm2d
                replacer.weight.set_((self.bn.weight * self.g.view(-1))[mask != 0])
                replacer.bias.set_((self.bn.bias * self.g.view(-1))[mask != 0])
            return replacer

    The full code:

    import torch
    import torch.nn as nn
    
    import numpy as np
    import uuid
    
    OBSERVE_TIMES = 5
    FINISH_SIGNAL = 'finish'
    
    class Meltable(nn.Module):
        def __init__(self):
            super(Meltable, self).__init__()
    
        @classmethod
        def melt_all(cls, net):
            def _melt(modules):
                keys = modules.keys()
                for k in keys:
                    if len(modules[k]._modules) > 0:
                        _melt(modules[k]._modules)
                    if isinstance(modules[k], Meltable):
                        modules[k] = modules[k].melt()
    
            _melt(net._modules)
    
        @classmethod
        def observe(cls, pack, lr):
            tmp = pack.train_loader
            if pack.tick_trainset is not None:
                pack.train_loader = pack.tick_trainset # the dataset used for observing
    
            for m in pack.net.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.weight.data.abs_().add_(1e-3)
    
            def replace_relu(modules): # replace ReLU with LeakyReLU
                keys = modules.keys()
                for k in keys:
                    if len(modules[k]._modules) > 0:
                        replace_relu(modules[k]._modules)
                    if isinstance(modules[k], nn.ReLU):
                        modules[k] = nn.LeakyReLU(inplace=True)
            replace_relu(pack.net._modules)
    
            count = 0
            def _freeze_bn(curr_iter, total_iter):
                for m in pack.net.modules():
                    if isinstance(m, nn.BatchNorm2d):
                        m.eval()
                nonlocal count
                count += 1
                if count == OBSERVE_TIMES:
                    return FINISH_SIGNAL
            info = pack.trainer.train(pack, iter_hook=_freeze_bn, update=False, mute=True) # update=False: the optimizer does not step
    
            def recover_relu(modules): # replace LeakyReLU back with ReLU
                keys = modules.keys()
                for k in keys:
                    if len(modules[k]._modules) > 0:
                        recover_relu(modules[k]._modules)
                    if isinstance(modules[k], nn.LeakyReLU):
                        modules[k] = nn.ReLU(inplace=True)
            recover_relu(pack.net._modules)
    
            for m in pack.net.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.weight.data.abs_().add_(-1e-3) # undo the earlier +1e-3 shift
    
            pack.train_loader = tmp
    
    
    class GatedBatchNorm2d(Meltable):
        def __init__(self, bn, minimal_ratio = 0.1):
            super(GatedBatchNorm2d, self).__init__()
            assert isinstance(bn, nn.BatchNorm2d)
            self.bn = bn
            self.group_id = uuid.uuid1()
    
            self.channel_size = bn.weight.shape[0]
            self.minimal_filter = max(1, int(self.channel_size * minimal_ratio)) # minimum number of channels to keep
            self.device = bn.weight.device
            self._hook = None
    
            self.g = nn.Parameter(torch.ones(1, self.channel_size, 1, 1).to(self.device), requires_grad=True) # a trainable parameter: the gate
            # the following lines register three buffers: self.area, self.score and self.bn_mask
            self.register_buffer('area', torch.zeros(1).to(self.device)) # nn.Module.register_buffer: persistent state reused across forward passes
            self.register_buffer('score', torch.zeros(1, self.channel_size, 1, 1).to(self.device))
            self.register_buffer('bn_mask', torch.ones(1, self.channel_size, 1, 1).to(self.device))
            # bn_mask records whether each channel of this BN layer has been pruned (0 = pruned); initialized to all ones
            self.extract_from_bn() # re-set the bn weight, bias and the gate g
    
        def set_groupid(self, new_id):
            self.group_id = new_id
    
        def extra_repr(self): # shows that after pruning the channel count goes from channel_size to bn_mask.sum()
            return '%d -> %d | ID: %s' % (self.channel_size, int(self.bn_mask.sum()), self.group_id)
    
        def extract_from_bn(self):
            # freeze bn weight
            with torch.no_grad():
                self.bn.bias.set_(torch.clamp(self.bn.bias / self.bn.weight, -10, 10)) # clamp self.bn.bias / self.bn.weight to [-10, 10]
                self.g.set_(self.g * self.bn.weight.view(1, -1, 1, 1))
                self.bn.weight.set_(torch.ones_like(self.bn.weight)) # torch.ones_like(input) is equivalent to torch.ones(input.size())
                self.bn.weight.requires_grad = False
    
        def reset_score(self):
            self.score.zero_()
    
        def cal_score(self, grad):
            # used for hook
            self.score += (grad * self.g).abs() # Eq. 6 in the paper: accumulate each gate's score, which estimates the loss change if that channel were removed; a score near 0 means the channel can be pruned
    
        def start_collecting_scores(self):
            if self._hook is not None:
                self._hook.remove()
    
            self._hook = self.g.register_hook(self.cal_score) # after backward computes the gradient of g, cal_score is called to update self.score, which is later used to rank channels
    
        def stop_collecting_scores(self):
            if self._hook is not None:
                self._hook.remove() # remove the hook obtained from register_hook
                self._hook = None
        
        def get_score(self, eta=0.0):
            # use self.bn_mask.sum() to calculate the number of input channels. eta should already be normalized
            # bn_mask is initialized to torch.ones(1, self.channel_size, 1, 1), so its sum() equals the number of channels still alive
            flops_reg = eta * int(self.area[0]) * self.bn_mask.sum()
            return ((self.score - flops_reg) * self.bn_mask).view(-1)
    
        def forward(self, x):
            x = self.bn(x) * self.g # self.g is the gate used to rank channel importance

            self.area[0] = x.shape[-1] * x.shape[-2] # height * width = feature-map area
    
            if self.bn_mask is not None:
                return x * self.bn_mask
            return x
    
        def melt(self):
            with torch.no_grad():
                mask = self.bn_mask.view(-1) # one value per channel; nonzero entries are the channels that survive pruning
                replacer = nn.BatchNorm2d(int(self.bn_mask.sum())).to(self.bn.weight.device)
                replacer.running_var.set_(self.bn.running_var[mask != 0]) # running variance from BatchNorm2d
                replacer.running_mean.set_(self.bn.running_mean[mask != 0]) # running mean from BatchNorm2d
                replacer.weight.set_((self.bn.weight * self.g.view(-1))[mask != 0])
                replacer.bias.set_((self.bn.bias * self.g.view(-1))[mask != 0])
            return replacer
    
        @classmethod
        def transform(cls, net, minimal_ratio=0.1):
            r = []
            def _inject(modules):
                keys = modules.keys()
                for k in keys:
                    if len(modules[k]._modules) > 0:
                        _inject(modules[k]._modules)
                    if isinstance(modules[k], nn.BatchNorm2d): # replace every nn.BatchNorm2d in the model with GatedBatchNorm2d; at least max(1, int(channel_size * minimal_ratio)) filters are kept after pruning
                        modules[k] = GatedBatchNorm2d(modules[k], minimal_ratio)
                        r.append(modules[k])
            _inject(net._modules)
            return r
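
    A hedged sketch of how these pieces fit together (the method names are from the code above, but the actual pruning driver in the repo is more involved):

    # wrap every BatchNorm2d with a gate; returns the list of GatedBatchNorm2d modules
    gbns = GatedBatchNorm2d.transform(pack.net, minimal_ratio=0.1)

    # collect importance scores for the gates over a few forward/backward passes
    for gbn in gbns:
        gbn.reset_score()
        gbn.start_collecting_scores()
    # ... run some training iterations here so the hooks accumulate gbn.score ...
    for gbn in gbns:
        gbn.stop_collecting_scores()

    # rank channels by gbn.get_score(), zero the bn_mask entries of the weakest ones,
    # then melt everything back into plain BatchNorm2d layers
    Meltable.melt_all(pack.net)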

    4)

    Pruning of the convolution layers:

    class Conv2dObserver(Meltable):
        def __init__(self, conv):
            super(Conv2dObserver, self).__init__()
            assert isinstance(conv, nn.Conv2d)
            self.conv = conv
            self.in_mask = torch.zeros(conv.in_channels).to('cpu')
            self.out_mask = torch.zeros(conv.out_channels).to('cpu')
            self.f_hook = conv.register_forward_hook(self._forward_hook) # hook run during this conv layer's forward pass
    
        def extra_repr(self):
            return '(%d, %d) -> (%d, %d)' % (self.conv.in_channels, self.conv.out_channels, int((self.in_mask != 0).sum()), int((self.out_mask != 0).sum()))
        
        def _forward_hook(self, m, _in, _out):
            x = _in[0] # in_mask records whether each input channel has been pruned (0 = pruned)
            self.in_mask += x.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1) # sum over the batch and spatial dims, keeping the channel dim; a channel whose total is 0 was pruned

        def _backward_hook(self, grad): # run after the gradient w.r.t. the output is computed in the backward pass
            self.out_mask += grad.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1) # same for the output channels: a total of 0 means that channel was pruned
            new_grad = torch.ones_like(grad)
            return new_grad
    
        def forward(self, x):
            output = self.conv(x)
            noise = torch.zeros_like(output).normal_()
            output = output + noise #?
            if self.training:
                output.register_hook(self._backward_hook)
            return output
    
        def melt(self):
            if self.conv.groups == 1:
                groups = 1
            elif self.conv.groups == self.conv.out_channels:
                groups = int((self.out_mask != 0).sum())
            else:
                assert False
    
            replacer = nn.Conv2d(
                in_channels = int((self.in_mask != 0).sum()),
                out_channels = int((self.out_mask != 0).sum()),
                kernel_size = self.conv.kernel_size,
                stride = self.conv.stride,
                padding = self.conv.padding,
                dilation = self.conv.dilation,
                groups = groups,
                bias = (self.conv.bias is not None)
            ).to(self.conv.weight.device)
    
            with torch.no_grad():
                if self.conv.groups == 1:
                    replacer.weight.set_(self.conv.weight[self.out_mask != 0][:, self.in_mask != 0])
                else:
                    replacer.weight.set_(self.conv.weight[self.out_mask != 0])
                if self.conv.bias is not None:
                    replacer.bias.set_(self.conv.bias[self.out_mask != 0])
            return replacer
        
        @classmethod
        def transform(cls, net):
            r = []
            def _inject(modules):
                keys = modules.keys()
                for k in keys:
                    if len(modules[k]._modules) > 0:
                        _inject(modules[k]._modules)
                    if isinstance(modules[k], nn.Conv2d):
                        modules[k] = Conv2dObserver(modules[k])
                        r.append(modules[k])
            _inject(net._modules)
            return r

    5) How the final fully-connected classification layer is handled:

    class FinalLinearObserver(Meltable):
        ''' assert was in the last layer. only input was masked '''
        def __init__(self, linear):
            super(FinalLinearObserver, self).__init__()
            assert isinstance(linear, nn.Linear)
            self.linear = linear
            self.in_mask = torch.zeros(linear.weight.shape[1]).to('cpu')
            self.f_hook = linear.register_forward_hook(self._forward_hook) # hook run during this linear layer's forward pass
        
        def extra_repr(self):
            return '(%d, %d) -> (%d, %d)' % (
                int(self.linear.weight.shape[1]),
                int(self.linear.weight.shape[0]),
                int((self.in_mask != 0).sum()),
                int(self.linear.weight.shape[0]))
    
        def _forward_hook(self, m, _in, _out):
            x = _in[0]
            self.in_mask += x.data.abs().cpu().sum(0, keepdim=True).view(-1) # sum over the batch dimension for each input feature; a feature that is always 0 corresponds to a pruned channel
    
        def forward(self, x):
            return self.linear(x)
    
        def melt(self): # replace with a linear layer that matches the pruned input channel count
            with torch.no_grad():
                replacer = nn.Linear(int((self.in_mask != 0).sum()), self.linear.weight.shape[0]).to(self.linear.weight.device)
                replacer.weight.set_(self.linear.weight[:, self.in_mask != 0])
                replacer.bias.set_(self.linear.bias)
            return replacer

    These two classes wrap the convolution layers and the final fully-connected layer as Conv2dObserver and FinalLinearObserver respectively.

    Conv2dObserver, for example, holds two buffers, in_mask and out_mask, which accumulate channel-wise sums during the forward and backward passes of training; a channel whose final sum is 0 has been pruned, i.e.:

        def _forward_hook(self, m, _in, _out):
            x = _in[0] # in_mask records whether each input channel has been pruned (0 = pruned)
            self.in_mask += x.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1) # sum over the batch and spatial dims, keeping the channel dim; a channel whose total is 0 was pruned

        def _backward_hook(self, grad): # run after the gradient w.r.t. the output is computed in the backward pass
            self.out_mask += grad.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1) # same for the output channels: a total of 0 means that channel was pruned
            new_grad = torch.ones_like(grad)
            return new_grad

    This is mainly used in the part marked in red in the figure below:

    That is, while the GBN layers are turned into pruned BN layers, the convolution and fully-connected layers use the in_mask and out_mask computed from the adjacent GBN layers to prune the corresponding filters, so that the channel counts of the whole network still match up.
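
    A hedged sketch of that melt step under these assumptions (Conv2dObserver, FinalLinearObserver and Meltable are the classes above; the classifier attribute name is hypothetical and the repo's actual driver code differs in detail):

    # once the GBN masks are fixed, wrap the convs and the final linear as observers
    conv_observers = Conv2dObserver.transform(pack.net)
    pack.net.module.fc = FinalLinearObserver(pack.net.module.fc)  # hypothetical classifier attribute

    # a few forward/backward passes let the observers fill in in_mask / out_mask
    Meltable.observe(pack, lr=0.001)

    # melt() replaces each observer / GBN with a smaller Conv2d / Linear / BatchNorm2d
    Meltable.melt_all(pack.net)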

    6) The gate loss function:

    def get_gate_sparse_loss(masks, sparse_lambda):
        def _loss_hook(data, label, logits):
            loss = 0.0
            for gbn in masks:
                if isinstance(gbn, GatedBatchNorm2d):
                    loss += gbn.g.abs().sum() 
            return sparse_lambda * loss
    
        return _loss_hook

    This computes the second half of the tock loss; looking at the code further on, it is passed in as loss_hook, so it really is an extra loss term.

    It corresponds to the following term in the paper:
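
    If this reading is right, the tock objective is the task loss plus an L1 sparsity term on the gates, with sparse_lambda from the config as the coefficient:

    \mathcal{L}_{tock} = \sum_{(x, y)} \ell\big(f(x;\theta),\, y\big) + \lambda \sum_{\phi \in \Phi} |\phi|

    where the second term is exactly what _loss_hook above returns (gbn.g.abs().sum() summed over all GatedBatchNorm2d layers, scaled by sparse_lambda).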

    Looking at the ResNet-56 network structure used in resnet56_prune:

    for name, module in pack.net.named_modules():
        print(name)
        print(module)

    Output:

    DataParallel(
      (module): ResNet(
        (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (layer1): Sequential(
          (0): BasicBlock(
            (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (shortcut): Sequential()
          )
          (1): BasicBlock(
            (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (shortcut): Sequential()
          )
          (2): BasicBlock(
            (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (shortcut): Sequential()
          )
          (3): BasicBlock(
            (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (shortcut): Sequential()
          )
          (4): BasicBlock(
            (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (shortcut): Sequential()
          )
          (5): BasicBlock(
            (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (shortcut): Sequential()
          )
          (6): BasicBlock(
            (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (shortcut): Sequential()
          )
          (7): BasicBlock(
            (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (shortcut): Sequential()
          )
          (8): BasicBlock(
            (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (shortcut): Sequential()
          )
        )
        (layer2): Sequential(
          (0): BasicBlock(
            (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (shortcut): Sequential(
              (0): Conv2d(16, 32, kernel_size=(1, 1), stride=(2, 2), bias=False)
              (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
    ...

    7)

    The remaining Tick-Tock part is probably easier to explain with a concrete example; see the follow-up post "Gate Decorator: Global Filter Pruning Method for Accelerating Deep Convolutional Neural Networks - Model Compression - 3 - Code Study, VGG16, Resnet".
