The Code of Pruning Filters For Efficient ConvNets
1. 代码参考
https://github.com/tyui592/Pruning_filters_for_efficient_convnets
其中主要是用VGG在CIFAR100上进行剪枝。代码中的 args 是一些配置参数,比如VGG的权重路径等信息。
def prune_network(args, network=None):
    """Prune a network with `prune_step` and optionally retrain it.

    Args:
        args: Configuration object. This function reads ``gpu_no``, ``vgg``,
            ``data_set``, ``load_path``, ``prune_layers``, ``prune_channels``,
            ``independent_prune_flag``, ``retrain_flag``, ``retrain_epoch``
            and ``retrain_lr``. When retraining it overwrites ``epoch``,
            ``lr`` and ``lr_milestone`` on the same object.
        network: Model to prune. When ``None`` a ``VGG`` is constructed and,
            if ``args.load_path`` is set, its weights are restored from the
            ``'state_dict'`` entry of that checkpoint.

    Returns:
        The pruned network, moved to the selected device and retrained when
        ``args.retrain_flag`` is set.
    """
    device = torch.device("cuda" if args.gpu_no >= 0 else "cpu")

    if network is None:
        network = VGG(args.vgg, args.data_set)
        if args.load_path:
            # Load onto the CPU so checkpoints saved from a GPU can also be
            # restored on CPU-only hosts; pruning runs on the CPU anyway.
            # NOTE(review): torch.load unpickles the file - only load
            # checkpoints from a trusted source.
            check_point = torch.load(args.load_path, map_location="cpu")
            network.load_state_dict(check_point['state_dict'])

    # prune network
    network = prune_step(network, args.prune_layers, args.prune_channels,
                         args.independent_prune_flag)
    network = network.to(device)
    print("-*-"*10 + " Prune network " + "-*-"*10)

    if args.retrain_flag:
        # update arguments for retraining the pruned network
        args.epoch = args.retrain_epoch
        args.lr = args.retrain_lr
        args.lr_milestone = None  # don't decay learning rate

        network = train_network(args, network)

    return network


def prune_step(network, prune_layers, prune_channels, independent_prune_flag):
    """Remove the weakest filters from the selected conv layers.

    Walks ``network.features`` in order. For every conv layer whose name
    ``'conv<N>'`` (N counted from 1 over the Conv2d modules) is listed in
    ``prune_layers``, the filters picked by `get_channel_index` are removed.
    The following BatchNorm2d and the input channels of the next Conv2d are
    trimmed with the same indices. If the last conv layer was pruned, the
    input features of ``network.classifier[0]`` are trimmed as well.

    Args:
        network: Model exposing ``features`` and ``classifier`` sequences.
        prune_layers: Collection of layer names such as ``'conv1'``.
        prune_channels: Number of filters to remove; the k-th entry belongs
            to the k-th pruned conv layer in network order.
        independent_prune_flag: Forwarded to `get_new_conv`; when true the
            removed input-channel weights are kept as ``residue`` and added
            back into the filter ranking of that layer.

    Returns:
        The same network object (moved to the CPU) with pruned layers.
    """
    network = network.cpu()  # pruning is performed on CPU tensors

    count = 0         # index into 'prune_channels'
    conv_count = 1    # running conv number used to build 'conv<N>' names
    # Axis of the conv weight [out_ch, in_ch, k1, k2] to prune next.
    # 0: remove output filters of the current conv (its BatchNorm and the
    #    next conv must then be trimmed to match).
    # 1: the previous conv lost filters, so the matching input channels of
    #    the current conv have to be removed.
    dim = 0
    residue = None    # removed input-channel weights ('independent' strategy)
    channel_index = None

    for i in range(len(network.features)):
        layer = network.features[i]

        if isinstance(layer, torch.nn.Conv2d):
            if dim == 1:
                # The previous conv was pruned: drop the same channels from
                # this conv's input side.
                layer, residue = get_new_conv(layer, dim, channel_index,
                                              independent_prune_flag)
                network.features[i] = layer
                dim ^= 1

            if 'conv%d' % conv_count in prune_layers:
                # Rank the filters of this conv and remove the weakest ones.
                channel_index = get_channel_index(layer.weight.data,
                                                  prune_channels[count],
                                                  residue)
                network.features[i] = get_new_conv(layer, dim, channel_index,
                                                   independent_prune_flag)
                dim ^= 1
                count += 1
            else:
                residue = None
            conv_count += 1

        elif dim == 1 and isinstance(layer, torch.nn.BatchNorm2d):
            # BatchNorm has per-channel parameters too; keep it in sync.
            network.features[i] = get_new_norm(layer, channel_index)

    # dim is still 1 only when the final conv layer was pruned and no later
    # conv consumed the change, so the classifier input must shrink. For a
    # 13-conv network this is the same as checking for 'conv13'.
    # NOTE(review): assumes each channel maps to exactly one classifier
    # input feature (1x1 spatial output) - confirm for other input sizes.
    if dim == 1:
        network.classifier[0] = get_new_linear(network.classifier[0],
                                               channel_index)

    return network


def get_channel_index(kernel, num_elimination, residue=None):
    """Return the indices of the filters with the smallest L1 norm.

    Args:
        kernel: Weight tensor whose first axis enumerates the filters.
        num_elimination: Number of filter indices to return.
        residue: Optional tensor with the same first-axis size; its absolute
            sum per filter is added to the ranking score.

    Returns:
        A list with ``num_elimination`` indices, ordered by ascending score.
    """
    # Sum of absolute values per filter; the smallest ones are pruned.
    sum_of_kernel = torch.sum(torch.abs(kernel.reshape(kernel.size(0), -1)), dim=1)
    if residue is not None:
        sum_of_kernel += torch.sum(torch.abs(residue.reshape(residue.size(0), -1)), dim=1)

    _, order = torch.sort(sum_of_kernel)

    return order[:num_elimination].tolist()


def index_remove(tensor, dim, index, removed=False):
    """Drop the entries listed in ``index`` along axis ``dim``.

    Args:
        tensor: Source tensor; it is copied to the CPU first if it is on CUDA.
        dim: Axis along which entries are removed.
        index: List of positions to remove (may be empty).
        removed: When true, also return the removed slice.

    Returns:
        The tensor without the listed entries, kept in their original order.
        With ``removed=True`` a tuple ``(kept, dropped)`` is returned, where
        ``dropped`` follows the order given in ``index``.
    """
    if tensor.is_cuda:
        tensor = tensor.cpu()

    # Build the kept positions in ascending order explicitly; iterating a set
    # difference does not guarantee any order and could permute channels.
    dropped = set(index)
    select_index = [i for i in range(tensor.size(dim)) if i not in dropped]

    # An explicit integer dtype keeps empty index lists valid for index_select.
    new_tensor = torch.index_select(
        tensor, dim, torch.tensor(select_index, dtype=torch.long))

    if removed:
        return new_tensor, torch.index_select(
            tensor, dim, torch.tensor(index, dtype=torch.long))

    return new_tensor
def get_new_conv(conv, dim, channel_index, independent_prune_flag=False):
    """Build a Conv2d equal to ``conv`` with some channels removed.

    Args:
        conv: Source ``torch.nn.Conv2d``.
        dim: 0 to remove output filters, 1 to remove input channels.
        channel_index: List of channel positions to remove.
        independent_prune_flag: Only used for ``dim == 1``; when true the
            removed input-channel weights are returned as the residue.

    Returns:
        For ``dim == 0`` the new conv layer. For ``dim == 1`` a tuple
        ``(new_conv, residue)`` where ``residue`` is ``None`` unless
        ``independent_prune_flag`` is set. Any other ``dim`` yields ``None``.
    """
    # Mirror the bias setting of the source layer so convs created with
    # bias=False are handled instead of failing on `conv.bias.data`.
    has_bias = conv.bias is not None

    if dim == 0:
        new_conv = torch.nn.Conv2d(in_channels=conv.in_channels,
                                   out_channels=int(conv.out_channels - len(channel_index)),
                                   kernel_size=conv.kernel_size,
                                   stride=conv.stride,
                                   padding=conv.padding,
                                   dilation=conv.dilation,
                                   bias=has_bias)

        new_conv.weight.data = index_remove(conv.weight.data, dim, channel_index)
        if has_bias:
            # One bias value per output filter: prune it with the same indices.
            new_conv.bias.data = index_remove(conv.bias.data, dim, channel_index)

        return new_conv

    elif dim == 1:
        new_conv = torch.nn.Conv2d(in_channels=int(conv.in_channels - len(channel_index)),
                                   out_channels=conv.out_channels,
                                   kernel_size=conv.kernel_size,
                                   stride=conv.stride,
                                   padding=conv.padding,
                                   dilation=conv.dilation,
                                   bias=has_bias)

        # With the flag set, index_remove returns (kept, removed).
        new_weight = index_remove(conv.weight.data, dim, channel_index, independent_prune_flag)
        residue = None
        if independent_prune_flag:
            new_weight, residue = new_weight
        new_conv.weight.data = new_weight
        if has_bias:
            # Output filters are untouched, so the bias is reused unchanged.
            new_conv.bias.data = conv.bias.data

        return new_conv, residue


def get_new_norm(norm, channel_index):
    """Build a BatchNorm2d equal to ``norm`` with some channels removed.

    Args:
        norm: Source ``torch.nn.BatchNorm2d``.
        channel_index: List of channel positions to remove.

    Returns:
        A new BatchNorm2d whose affine parameters and running statistics
        (where the source has them) are the source values without the
        removed channels.
    """
    new_norm = torch.nn.BatchNorm2d(num_features=int(norm.num_features - len(channel_index)),
                                    eps=norm.eps,
                                    momentum=norm.momentum,
                                    affine=norm.affine,
                                    track_running_stats=norm.track_running_stats)

    # weight/bias only exist when the layer is affine.
    if norm.affine:
        new_norm.weight.data = index_remove(norm.weight.data, 0, channel_index)
        new_norm.bias.data = index_remove(norm.bias.data, 0, channel_index)

    if norm.track_running_stats:
        new_norm.running_mean.data = index_remove(norm.running_mean.data, 0, channel_index)
        new_norm.running_var.data = index_remove(norm.running_var.data, 0, channel_index)
        # Carry over the batch counter when both layers have one, so the
        # new layer does not restart its count at zero.
        tracked = getattr(norm, 'num_batches_tracked', None)
        if tracked is not None and getattr(new_norm, 'num_batches_tracked', None) is not None:
            new_norm.num_batches_tracked.copy_(tracked)

    return new_norm
# The fully connected layer must shrink too, because the number of filters
# feeding it has changed.
def get_new_linear(linear, channel_index):
    """Build a Linear layer equal to ``linear`` with some inputs removed.

    Args:
        linear: Source ``torch.nn.Linear``.
        channel_index: List of input-feature positions to remove.

    Returns:
        A new Linear layer with ``len(channel_index)`` fewer input features,
        the matching weight columns removed and the bias (if any) reused.
    """
    has_bias = linear.bias is not None
    new_linear = torch.nn.Linear(in_features=int(linear.in_features - len(channel_index)),
                                 out_features=linear.out_features,
                                 bias=has_bias)
    # Linear weight is [out_features, in_features]; inputs live on axis 1.
    new_linear.weight.data = index_remove(linear.weight.data, 1, channel_index)
    if has_bias:
        # Output features are unchanged, so the bias carries over as is.
        new_linear.bias.data = linear.bias.data

    return new_linear