zoukankan      html  css  js  c++  java
  • Pruning Filters For Efficient ConvNets 剪枝代码小结

    The Code of Pruning Filters For Efficient ConvNets

    1. 代码参考

     https://github.com/tyui592/Pruning_filters_for_efficient_convnets

     其中主要是用 VGG 在 CIFAR100 上进行剪枝。代码中的 args 是一组配置参数,例如 VGG 权重文件的路径等信息。

    def prune_network(args, network=None):
        """Prune a VGG network and optionally retrain it.

        Args:
            args: configuration namespace. Read here: ``gpu_no``, ``vgg``,
                ``data_set``, ``load_path``, ``prune_layers``, ``prune_channels``,
                ``independent_prune_flag``, ``retrain_flag``, ``retrain_epoch``,
                ``retrain_lr``. When retraining, ``epoch``, ``lr`` and
                ``lr_milestone`` are overwritten before calling ``train_network``.
            network: an already-built model. When ``None``, a ``VGG`` is built
                from ``args.vgg`` / ``args.data_set`` and, if ``args.load_path``
                is set, its weights are loaded from the checkpoint's
                ``'state_dict'`` entry.

        Returns:
            The pruned (and, if requested, retrained) network on the selected device.
        """
        # CUDA when a non-negative GPU index is configured, otherwise the CPU.
        device = torch.device("cuda" if args.gpu_no >= 0 else "cpu")
        if network is None:
            network = VGG(args.vgg, args.data_set)
            if args.load_path:
                # prune_step works on the CPU anyway; mapping to the CPU also lets a
                # checkpoint saved on a GPU machine be loaded on a CPU-only machine.
                check_point = torch.load(args.load_path, map_location="cpu")
                network.load_state_dict(check_point['state_dict'])

        # prune network
        network = prune_step(network, args.prune_layers, args.prune_channels, args.independent_prune_flag)
        network = network.to(device)
        print("-*-"*10 + "\n\tPrune network\n" + "-*-"*10)

        if args.retrain_flag:
            # update arguments for retraining the pruned network
            args.epoch = args.retrain_epoch
            args.lr = args.retrain_lr
            args.lr_milestone = None  # don't decay learning rate
            network = train_network(args, network)
        return network
    
    def prune_step(network, prune_layers, prune_channels, independent_prune_flag):
        """Remove the lowest-L1-norm filters from selected convs in ``network.features``.

        Args:
            network: model exposing ``features`` (an indexable sequence of layers)
                and ``classifier`` (whose element 0 is the linear layer fed by the
                last conv block).
            prune_layers: names ``'conv1'``, ``'conv2'``, ... of the convolutions to
                prune, numbered in order of appearance in ``network.features``.
            prune_channels: number of filters to remove for each pruned conv, in
                the order the pruned convs are visited.
            independent_prune_flag: if true, a conv's filters are ranked including
                the kernels of input channels already removed because of the
                previous layer (independent strategy); otherwise those kernels are
                ignored (greedy strategy).

        Returns:
            The network, moved to the CPU, with pruned layers replaced.
        """
        network = network.cpu()  # pruning is performed on the CPU
        count = 0  # index into 'prune_channels'
        conv_count = 1  # 1-based conv number used to build names checked against 'prune_layers'
        # dim selects which axis of a conv weight [out_ch, in_ch, k1, k2] is pruned next:
        #   0 -> this conv keeps its inputs; remove some of its output filters, and then
        #        the following BN channels and the next conv's matching inputs;
        #   1 -> the previous conv lost output filters, so remove the matching input
        #        channels (feature maps) of this conv.
        dim = 0
        channel_index = None  # channels removed from the most recently pruned conv
        residue = None  # removed input kernels, needed by the independent strategy
        for i in range(len(network.features)):
            layer = network.features[i]
            if isinstance(layer, torch.nn.Conv2d):
                if dim == 1:
                    # The previous conv lost filters: drop the matching input channels here.
                    layer, residue = get_new_conv(layer, dim, channel_index, independent_prune_flag)
                    network.features[i] = layer
                    dim ^= 1
                if 'conv%d' % conv_count in prune_layers:
                    # dim is 0 here: rank this conv's output filters and drop the weakest.
                    channel_index = get_channel_index(layer.weight.data, prune_channels[count], residue)
                    layer = get_new_conv(layer, dim, channel_index, independent_prune_flag)
                    network.features[i] = layer
                    dim ^= 1
                    count += 1
                else:
                    residue = None
                conv_count += 1
            elif dim == 1 and isinstance(layer, torch.nn.BatchNorm2d):
                # A BN layer after a pruned conv has one channel per filter; drop the same ones.
                network.features[i] = get_new_norm(layer, channel_index)

        # dim == 1 after the loop means the last conv's filters were pruned and no later
        # conv absorbed them, so the classifier's first layer must drop the same inputs.
        # Checking this state (instead of the VGG16-only name 'conv13') works for any depth.
        # NOTE(review): assumes one classifier input per channel (1x1 spatial size after
        # the last pooling layer) — confirm for the model/dataset in use.
        if dim == 1:
            network.classifier[0] = get_new_linear(network.classifier[0], channel_index)
        return network


    def get_channel_index(kernel, num_elimination, residue=None):
        """Return the indices of the ``num_elimination`` filters with the smallest L1 norm.

        Args:
            kernel: weight tensor whose first axis indexes filters
                (for a conv weight, ``[out_ch, in_ch, k1, k2]``).
            num_elimination: how many filter indices to return.
            residue: optional tensor sharing ``kernel``'s first axis; its absolute
                values are added to each filter's score.

        Returns:
            A list of filter indices, ordered from smallest to largest score.
        """
        # Sum of absolute weights per filter; reshape (unlike view) also accepts
        # non-contiguous tensors.
        scores = kernel.reshape(kernel.size(0), -1).abs().sum(dim=1)
        if residue is not None:
            scores = scores + residue.reshape(residue.size(0), -1).abs().sum(dim=1)
        return torch.argsort(scores)[:num_elimination].tolist()


    def index_remove(tensor, dim, index, removed=False):
        """Return ``tensor`` without the slices at positions ``index`` along ``dim``.

        Args:
            tensor: tensor to slice; CUDA tensors are copied to the CPU first.
            dim: axis to remove slices from.
            index: list of integer positions along ``dim`` to drop.
            removed: if true, also return the dropped slices.

        Returns:
            The kept slices in their original (ascending) order, or a tuple
            ``(kept, dropped)`` when ``removed`` is true, where ``dropped`` follows
            the order of ``index``.
        """
        if tensor.is_cuda:
            tensor = tensor.cpu()
        drop = set(index)
        # Walk range() so the kept positions are ascending; a set difference would
        # leave their order to set iteration order, which is not guaranteed sorted.
        keep = [i for i in range(tensor.size(dim)) if i not in drop]
        # Explicit long dtype: torch.tensor([]) would be float, which index_select rejects.
        new_tensor = torch.index_select(tensor, dim, torch.tensor(keep, dtype=torch.long))
        if removed:
            return new_tensor, torch.index_select(tensor, dim, torch.tensor(index, dtype=torch.long))
        return new_tensor
    def get_new_conv(conv, dim, channel_index, independent_prune_flag=False):
        """Build a copy of ``conv`` with output (dim=0) or input (dim=1) channels removed.

        Args:
            conv: the ``torch.nn.Conv2d`` to shrink.
            dim: 0 to remove output filters, 1 to remove input channels.
            channel_index: channel positions to remove along ``dim``.
            independent_prune_flag: only used when ``dim == 1``; if true, the
                removed input kernels are also returned as ``residue``.

        Returns:
            ``new_conv`` when ``dim == 0``; ``(new_conv, residue)`` when ``dim == 1``,
            where ``residue`` is ``None`` unless ``independent_prune_flag`` is set.

        Raises:
            ValueError: if ``dim`` is not 0 or 1.
        """
        if dim not in (0, 1):
            raise ValueError("dim must be 0 or 1, got %r" % (dim,))
        removed = len(channel_index)
        has_bias = conv.bias is not None
        # NOTE(review): groups is not copied, so this assumes groups == 1.
        new_conv = torch.nn.Conv2d(
            in_channels=int(conv.in_channels - removed) if dim == 1 else conv.in_channels,
            out_channels=int(conv.out_channels - removed) if dim == 0 else conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            bias=has_bias,
            padding_mode=conv.padding_mode,
        )
        if dim == 0:
            new_conv.weight.data = index_remove(conv.weight.data, dim, channel_index)
            if has_bias:
                new_conv.bias.data = index_remove(conv.bias.data, dim, channel_index)
            return new_conv

        new_weight = index_remove(conv.weight.data, dim, channel_index, independent_prune_flag)
        residue = None
        if independent_prune_flag:
            new_weight, residue = new_weight
        new_conv.weight.data = new_weight
        if has_bias:
            # Bias is per output filter, so removing input channels leaves it unchanged.
            new_conv.bias.data = conv.bias.data
        return new_conv, residue


    def get_new_norm(norm, channel_index):
        """Build a copy of the BatchNorm2d ``norm`` without the channels in ``channel_index``.

        Args:
            norm: the ``torch.nn.BatchNorm2d`` to shrink.
            channel_index: channel positions to remove.

        Returns:
            A new ``torch.nn.BatchNorm2d`` with the same hyper-parameters, whose affine
            parameters and running statistics (when present) keep only the remaining channels.
        """
        new_norm = torch.nn.BatchNorm2d(num_features=int(norm.num_features - len(channel_index)),
                                        eps=norm.eps,
                                        momentum=norm.momentum,
                                        affine=norm.affine,
                                        track_running_stats=norm.track_running_stats)
        if norm.affine:
            # weight/bias only exist when affine=True.
            new_norm.weight.data = index_remove(norm.weight.data, 0, channel_index)
            new_norm.bias.data = index_remove(norm.bias.data, 0, channel_index)
        if norm.track_running_stats:
            new_norm.running_mean.data = index_remove(norm.running_mean.data, 0, channel_index)
            new_norm.running_var.data = index_remove(norm.running_var.data, 0, channel_index)
            # Keep the batch counter; with momentum=None it drives the cumulative average.
            new_norm.num_batches_tracked.data = norm.num_batches_tracked.data.clone()
        return new_norm

    # The fully connected layer must change too: the number of conv filters changed,
    # so the input features indexed by the removed channels are dropped.
    def get_new_linear(linear, channel_index):
        """Build a copy of ``linear`` without the input features in ``channel_index``.

        Args:
            linear: the ``torch.nn.Linear`` to shrink.
            channel_index: input feature positions to remove.

        Returns:
            A new ``torch.nn.Linear`` with ``len(channel_index)`` fewer input features,
            the same output features, and the same bias (if the original had one).
        """
        has_bias = linear.bias is not None
        new_linear = torch.nn.Linear(in_features=int(linear.in_features - len(channel_index)),
                                     out_features=linear.out_features,
                                     bias=has_bias)
        new_linear.weight.data = index_remove(linear.weight.data, 1, channel_index)
        if has_bias:
            # Bias is per output feature, so removing inputs leaves it unchanged.
            new_linear.bias.data = linear.bias.data
        return new_linear
  • 相关阅读:
    loadrunner(预测系统行为和性能的负载测试工具)
    SOA(面向服务的架构)
    Acunetix Web Vulnerability Scanner(WVS)(Acunetix网络漏洞扫描器)
    Redis主从复制
    Redis启动方式
    Struts2.5以上版本There is no Action mapped for namespace [/] and action name [userAction_login] associated with context path []
    使用Slf4j集成Log4j2构建项目日志系统的完美解决方案
    Write operations are not allowed in read-only mode (FlushMode.MANUAL): Turn your Session into FlushMode.COMMIT/AUTO or remove 'readOnly' marker from transaction definition.
    intellij 编译 springmvc+hibernate+spring+maven 找不到hbm.xml映射文件
    MySQL8连接数据库
  • 原文地址:https://www.cnblogs.com/zonechen/p/13373259.html
Copyright © 2011-2022 走看看