zoukankan      html  css  js  c++  java
  • Pytorch 剪枝操作实现

    Pytorch 剪枝操作实现

    首先需要版本为 1.4 以上,

    目前很多模型都取得了十分好的结果, 但是还是参数太多, 占得权重太大, 所以我们的目标是得到一个稀疏的子系数矩阵.

    这个例子是基于 LeNet 的 Pytorch 实现的例子, 我们从 CNN 的角度来剪枝, 其实在全连接层与 RNN 的剪枝应该是类似, 首先导入一些必要的模块

    import torch
    from torch import nn
    import torch.nn.utils.prune as prune
    import torch.nn.functional as F
    

    然后是 LeNet 的网络结构, 不知道为什么这里的网络结构是这样的, 算出来输入的图像是 26x26 的,

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    class LeNet(nn.Module):
        def __init__(self):
            super(LeNet, self).__init__()
            # 1 input image channel, 6 output channels, 3x3 square conv kernel
            self.conv1 = nn.Conv2d(1, 6, 3)
            # 第一个卷积层, 输出的向量维度是 6
            self.conv2 = nn.Conv2d(6, 16, 3)
            # 第二个卷积层, 输出的向量维度是 16
            self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
            # 最后将二维向量变成一维
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
    
        def forward(self, x):
            x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
            # 2*2 的池化层
            x = F.max_pool2d(F.relu(self.conv2(x)), 2)
            # relu 激活函数层
            x = x.view(-1, int(x.nelement() / x.shape[0]))
            # 除以 batch_size 的大小, 将维度变成一
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    
    model = LeNet().to(device=device)
    

    这时查看模型的参数:

    module = model.conv1
    print(list(module.named_parameters()))
    

    此时参数包含矩阵的权值与偏置.

    为了剪枝一个模型, 首先要在 torch.nn.utils.prune 中选择一种剪枝方法, 或者使用子类 BasePruningMethod 实现自己的剪枝方法, 然后确定模型以及需要减去的参数, 最后,使用所选修剪技术所需的适当关键字参数,指定修剪参数. 在下面的例子中, 我们将要随机减去 conv1 层中的 30% 的权重参数, module 是函数的第一个参数, name 使用的是参数的字符串标识, amount 表示剪枝的百分比.

    prune.random_unstructured(module, name="weight", amount=0.3)
    

    剪枝行为将 weight 参数名称删除, 并将其替代为新的参数名称, weight_orig , weight_orig存储未修剪的张量版本. 也就是说 weight_orig 是原来的权重,

    上述的剪枝方法会产生一个 mask 矩阵, 叫做 weight_mask , 存储为一个 module buffer , 相当于一个 mask矩阵, 他的维度与 weight 的维度相同, 不同的是 mask 矩阵是一个 0/1 矩阵. 可以通过下面的函数查看 mask 矩阵:

    print(list(module.named_buffers()))
    

    剪枝之后的权重属性 weight 不再是权重的集合, 而是 mask 矩阵与原始矩阵的结合, 所以不再是模型的一个 parameter, 而是一个 attribute.

    最后,使用 PyTorch 的forward_pre_hooks在每次正向传递之前应用修剪。具体来说,如我们在此处所做的那样,在剪枝模块部分,它将为与之相关的每个要修剪的参数获取一个forward_pre_hook。目前为止我们只修剪了名为weight的原始参数,因此将只存在一个 forward_pre_hook, 相当于没有一个剪枝参数就有一个 forward_pre_hook.

    除了对 weight 剪枝, 还可以对 bias 剪枝, 下面是通过 L1 范式剪去三个单元

    prune.l1_unstructured(module, name="bias", amount=3)
    # Prunes tensor corresponding to parameter called name in module by removing the specified amount of (currently unpruned) units with the lowest L1-norm.
    

    Iterative Pruning

    相同的参数在一个模型中可以被多次剪枝, 相当于把多个剪枝核序列化成一个剪枝核, 新的 mask 矩阵与旧的 mask 矩阵的结合使用 PruningContainer 中的 compute_mask 方法. 比如在上面的 module 的 weight 中, 我们除了随机剪枝外还可以通过范式剪枝, 下面是个例子:

    prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
    # As we can verify, this will zero out all the connections corresponding to 
    # 50% (3 out of 6) of the channels, while preserving the action of the 
    # previous mask.
    # 这里的 n 表示剪枝的范式, dim = 0, 表示参数矩阵的维度, 这里卷积层的 dim= 0, 就是核的个数
    print(module.weight)
    

    剪完之后, 核的个数变成原来的一半. mask 矩阵也会自动叠加.

    还可以通过下面的方法查看我们使用了哪些方法剪枝, hook 记录了某个 attribute 的剪枝方法:

    for hook in module._forward_pre_hooks.values():
        if hook._tensor_name == "weight":  # select out the correct hook
            break
    
    print(list(hook))  # pruning history in the container
    

    Serializing a pruned model

    所有相关的张量,包括掩码缓冲区和用于计算修剪的张量的原始参数,都存储在模型的 state_dict 中,因此可以根据需要轻松地序列化和保存.

    我们可以通过下面的方法查看模型中的权重参数:

    >> print(model.state_dict().keys())
    >> odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
    

    Remove pruning re-parametrization

    注意, 这里的删除剪枝的意思并不是真正的删除, 还原到未剪枝的状态. 举个例子, 剪枝之后, 我们的参数 parameters 中的 weight 会变成, 'weight_orig', 而 weight 变成一个属性, 他是 'weight_orig' 与 mask 矩阵结合后的结果, 那么

    prune.remove(module, 'weight')
    

    之后会发生什么呢?

    print(list(module.named_parameters()))
    ('weight', Parameter containing:
    tensor([[[[-0.0000, -0.0000, -0.0000],
              [-0.0000, -0.0000, -0.0000],
              [-0.0000,  0.0000, -0.0000]]],
    .......
    

    也就是说, weight 又变成了 parameters, 剪枝变成永久化.

    Pruning multiple parameters in a model

    多个参数, 多个网络结构的剪枝,

    new_model = LeNet()
    for name, module in new_model.named_modules():
        # prune 20% of connections in all 2D-conv layers 
        if isinstance(module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=0.2)
            # 将所有卷积层的权重减去 20%
        # prune 40% of connections in all linear layers 
        elif isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=0.4)
            # 将所有全连接层的权重减去 40%
    
    print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist
    

    Global pruning

    之前的剪枝我们都是针对每一层每一层的剪枝, 减去某一层权重的百分比, 对于全局剪枝就是将模型的参数看成一个整体, 减去一部分参数, 对于每一层减去的比例可能不同.

    剪枝的方法可以通过下面的方法:

    model = LeNet()
    
    parameters_to_prune = (
        (model.conv1, 'weight'),
        (model.conv2, 'weight'),
        (model.fc1, 'weight'),
        (model.fc2, 'weight'),
        (model.fc3, 'weight'),
    )
    
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=0.2,
    )
    

    使用自定义的方法剪枝

    要实现自己的修剪功能,可以通过将 BasePruningMethod 基类作为子类来扩展 nn.utils.prune 模块,就像其他所有修剪方法一样. 基类以及完成了下面的方法:

    __call__, apply_mask, apply, prune, and remove

    除了一些特殊的情况, 你不需要重写这些方法以实现新的剪枝方法. 你需要实现的是:

    1. __init__ 构造器
    2. compute_mask 如何根据剪枝策略的逻辑为给定张量计算 mask
    3. 需要说明是全局剪枝, 还是结构剪枝, 或者是非结构剪枝, 这决定了在迭代剪枝是如何结合 mask 矩阵, 换句话说,当剪枝需要剪枝的参数时,当前的剪枝策略应作用于参数的未剪枝部分。指定 PRUNING_TYPE 将启用 PruningContainer 正确识别要修剪的参数的范围.

    比如说, 当我们希望剪枝一个张量中除了某一参数外的所有其他参数的时候, 或者说这个张量已经被部分剪枝的时候, 我们就需要设置: PRUNING_TYPE='unstructured' 因为他只是单独作用与一层, 而不是一个单元或者通道(对应于'structured'), 也不是作用于整个参数(对应于'global')

    class FooBarPruningMethod(prune.BasePruningMethod):
        # 继承自基类 BasePruningMethod
        """Prune every other entry in a tensor
        """
        PRUNING_TYPE = 'unstructured'
        # 类型为 unstructured 类型
    
        def compute_mask(self, t, default_mask):
            mask = default_mask.clone()
            mask.view(-1)[::2] = 0 
            # 定义了 mask 矩阵的构成方法, 每两个数字一个 0
            return mask
    

    然后给出一个调用的例子:

    def foobar_unstructured(module, name):
        """Prunes tensor corresponding to parameter called `name` in `module`
        by removing every other entry in the tensors.
        Modifies module in place (and also return the modified module) 
        by:
        1) adding a named buffer called `name+'_mask'` corresponding to the 
        binary mask applied to the parameter `name` by the pruning method.
        The parameter `name` is replaced by its pruned version, while the 
        original (unpruned) parameter is stored in a new parameter named 
        `name+'_orig'`.
    
        Args:
            module (nn.Module): module containing the tensor to prune
            name (string): parameter name within `module` on which pruning
                    will act.
    
        Returns:
            module (nn.Module): modified (i.e. pruned) version of the input
                module
        
        Examples:
            >>> m = nn.Linear(3, 4)
            >>> foobar_unstructured(m, name='bias')
        """
        FooBarPruningMethod.apply(module, name)
        return module
    
    model = LeNet()
    foobar_unstructured(model.fc3, name='bias')
    
    print(model.fc3.bias_mask)
    

    以上就是Pytorch 剪枝的主要方法, 其实对于复杂的剪枝方法, 只要在 compute_mask 设置特殊的 mask 构成方法就可以了.

  • 相关阅读:
    最受欢迎的北大通选课导读·1[精品]
    社会保险,
    养老金的计算,
    毫秒 后的一个计算,
    返回格式 的数据结构再次改造,
    阶段状态池子,
    生活,-摘
    融合,
    tableview 也可以实现这个效果,
    字体大小 一致起来,
  • 原文地址:https://www.cnblogs.com/wevolf/p/12587225.html
Copyright © 2011-2022 走看看