  • PyTorch: pruning a pretrained BERT model

    Overall process

    Pruning the number of layers

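    1. Load the pretrained model;
    2. Extract the weights of the layers to keep and rename them. For example, to keep layers 0 and 11, keep layer 11's weights and rename them with layer 1's names;
    3. Modify the model config (num_hidden_layers = the number of layers kept) and assign layer 11's weights to layer 1;
    4. Save the model as pytorch_model.bin.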
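    Steps 2 and 3 boil down to a renaming rule on parameter names. The sketch below generalizes it to any list of kept layers, assuming the `encoder.layer.<i>.` naming that transformers' BertModel uses; `remap_layer_keys` is a hypothetical helper for illustration, not a transformers function. Applied to `model.state_dict()`, the result is meant to load into a BertModel whose config sets `num_hidden_layers = len(keep_layers)`.

```python
import re


def remap_layer_keys(state_dict, keep_layers, prefix='encoder.layer.'):
    """Keep only the encoder layers in keep_layers and renumber them 0..n-1.

    Entries that are not encoder layers (embeddings, pooler, buffers) are copied as-is.
    """
    new_index = {old: new for new, old in enumerate(keep_layers)}
    pattern = re.compile('^' + re.escape(prefix) + r'(\d+)\.(.*)$')
    pruned = {}
    for name, value in state_dict.items():
        match = pattern.match(name)
        if match is None:
            pruned[name] = value                 # not an encoder layer: keep unchanged
            continue
        old = int(match.group(1))
        if old in new_index:                     # kept layer: rename to its new index
            pruned['{}{}.{}'.format(prefix, new_index[old], match.group(2))] = value
        # any other layer is dropped
    return pruned


# Placeholder strings stand in for tensors
sd = {
    'embeddings.word_embeddings.weight': 'emb',
    'encoder.layer.0.output.dense.weight': 'L0',
    'encoder.layer.1.output.dense.weight': 'L1',
    'encoder.layer.11.output.dense.weight': 'L11',
    'pooler.dense.weight': 'pool',
}
print(remap_layer_keys(sd, keep_layers=[0, 11]))
```

    With `keep_layers=[0, 11]`, layer 11's entry ends up under `encoder.layer.1.`, the original layer 1 is dropped, and the embeddings and pooler pass through unchanged. The regex anchors on the full `encoder.layer.<digits>.` prefix, so `encoder.layer.11.` is never mistaken for `encoder.layer.1.`.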
    First, let's look at which weights BERT actually has:

    import torch
    from transformers import BertTokenizer, BertModel
    
    bertModel = BertModel.from_pretrained('bert-base-chinese', output_hidden_states=True, output_attentions=True)
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    for name,param in bertModel.named_parameters():
      print(name, param.shape)
    
    embeddings.word_embeddings.weight torch.Size([21128, 768])
    embeddings.position_embeddings.weight torch.Size([512, 768])
    embeddings.token_type_embeddings.weight torch.Size([2, 768])
    embeddings.LayerNorm.weight torch.Size([768])
    embeddings.LayerNorm.bias torch.Size([768])
    encoder.layer.0.attention.self.query.weight torch.Size([768, 768])
    encoder.layer.0.attention.self.query.bias torch.Size([768])
    encoder.layer.0.attention.self.key.weight torch.Size([768, 768])
    encoder.layer.0.attention.self.key.bias torch.Size([768])
    encoder.layer.0.attention.self.value.weight torch.Size([768, 768])
    encoder.layer.0.attention.self.value.bias torch.Size([768])
    encoder.layer.0.attention.output.dense.weight torch.Size([768, 768])
    encoder.layer.0.attention.output.dense.bias torch.Size([768])
    encoder.layer.0.attention.output.LayerNorm.weight torch.Size([768])
    encoder.layer.0.attention.output.LayerNorm.bias torch.Size([768])
    encoder.layer.0.intermediate.dense.weight torch.Size([3072, 768])
    encoder.layer.0.intermediate.dense.bias torch.Size([3072])
    encoder.layer.0.output.dense.weight torch.Size([768, 3072])
    encoder.layer.0.output.dense.bias torch.Size([768])
    encoder.layer.0.output.LayerNorm.weight torch.Size([768])
    encoder.layer.0.output.LayerNorm.bias torch.Size([768])
    encoder.layer.1.attention.self.query.weight torch.Size([768, 768])
    encoder.layer.1.attention.self.query.bias torch.Size([768])
    encoder.layer.1.attention.self.key.weight torch.Size([768, 768])
    encoder.layer.1.attention.self.key.bias torch.Size([768])
    encoder.layer.1.attention.self.value.weight torch.Size([768, 768])
    encoder.layer.1.attention.self.value.bias torch.Size([768])
    encoder.layer.1.attention.output.dense.weight torch.Size([768, 768])
    encoder.layer.1.attention.output.dense.bias torch.Size([768])
    encoder.layer.1.attention.output.LayerNorm.weight torch.Size([768])
    encoder.layer.1.attention.output.LayerNorm.bias torch.Size([768])
    encoder.layer.1.intermediate.dense.weight torch.Size([3072, 768])
    encoder.layer.1.intermediate.dense.bias torch.Size([3072])
    encoder.layer.1.output.dense.weight torch.Size([768, 3072])
    encoder.layer.1.output.dense.bias torch.Size([768])
    encoder.layer.1.output.LayerNorm.weight torch.Size([768])
    encoder.layer.1.output.LayerNorm.bias torch.Size([768])
    encoder.layer.2.attention.self.query.weight torch.Size([768, 768])
    encoder.layer.2.attention.self.query.bias torch.Size([768])
    encoder.layer.2.attention.self.key.weight torch.Size([768, 768])
    encoder.layer.2.attention.self.key.bias torch.Size([768])
    encoder.layer.2.attention.self.value.weight torch.Size([768, 768])
    encoder.layer.2.attention.self.value.bias torch.Size([768])
    encoder.layer.2.attention.output.dense.weight torch.Size([768, 768])
    encoder.layer.2.attention.output.dense.bias torch.Size([768])
    encoder.layer.2.attention.output.LayerNorm.weight torch.Size([768])
    encoder.layer.2.attention.output.LayerNorm.bias torch.Size([768])
    encoder.layer.2.intermediate.dense.weight torch.Size([3072, 768])
    encoder.layer.2.intermediate.dense.bias torch.Size([3072])
    encoder.layer.2.output.dense.weight torch.Size([768, 3072])
    encoder.layer.2.output.dense.bias torch.Size([768])
    encoder.layer.2.output.LayerNorm.weight torch.Size([768])
    encoder.layer.2.output.LayerNorm.bias torch.Size([768])
    encoder.layer.3.attention.self.query.weight torch.Size([768, 768])
    encoder.layer.3.attention.self.query.bias torch.Size([768])
    encoder.layer.3.attention.self.key.weight torch.Size([768, 768])
    encoder.layer.3.attention.self.key.bias torch.Size([768])
    encoder.layer.3.attention.self.value.weight torch.Size([768, 768])
    encoder.layer.3.attention.self.value.bias torch.Size([768])
    encoder.layer.3.attention.output.dense.weight torch.Size([768, 768])
    encoder.layer.3.attention.output.dense.bias torch.Size([768])
    encoder.layer.3.attention.output.LayerNorm.weight torch.Size([768])
    encoder.layer.3.attention.output.LayerNorm.bias torch.Size([768])
    encoder.layer.3.intermediate.dense.weight torch.Size([3072, 768])
    encoder.layer.3.intermediate.dense.bias torch.Size([3072])
    encoder.layer.3.output.dense.weight torch.Size([768, 3072])
    encoder.layer.3.output.dense.bias torch.Size([768])
    encoder.layer.3.output.LayerNorm.weight torch.Size([768])
    encoder.layer.3.output.LayerNorm.bias torch.Size([768])
    encoder.layer.4.attention.self.query.weight torch.Size([768, 768])
    encoder.layer.4.attention.self.query.bias torch.Size([768])
    encoder.layer.4.attention.self.key.weight torch.Size([768, 768])
    encoder.layer.4.attention.self.key.bias torch.Size([768])
    encoder.layer.4.attention.self.value.weight torch.Size([768, 768])
    encoder.layer.4.attention.self.value.bias torch.Size([768])
    encoder.layer.4.attention.output.dense.weight torch.Size([768, 768])
    encoder.layer.4.attention.output.dense.bias torch.Size([768])
    encoder.layer.4.attention.output.LayerNorm.weight torch.Size([768])
    encoder.layer.4.attention.output.LayerNorm.bias torch.Size([768])
    encoder.layer.4.intermediate.dense.weight torch.Size([3072, 768])
    encoder.layer.4.intermediate.dense.bias torch.Size([3072])
    encoder.layer.4.output.dense.weight torch.Size([768, 3072])
    encoder.layer.4.output.dense.bias torch.Size([768])
    encoder.layer.4.output.LayerNorm.weight torch.Size([768])
    encoder.layer.4.output.LayerNorm.bias torch.Size([768])
    encoder.layer.5.attention.self.query.weight torch.Size([768, 768])
    encoder.layer.5.attention.self.query.bias torch.Size([768])
    encoder.layer.5.attention.self.key.weight torch.Size([768, 768])
    encoder.layer.5.attention.self.key.bias torch.Size([768])
    encoder.layer.5.attention.self.value.weight torch.Size([768, 768])
    encoder.layer.5.attention.self.value.bias torch.Size([768])
    encoder.layer.5.attention.output.dense.weight torch.Size([768, 768])
    encoder.layer.5.attention.output.dense.bias torch.Size([768])
    encoder.layer.5.attention.output.LayerNorm.weight torch.Size([768])
    encoder.layer.5.attention.output.LayerNorm.bias torch.Size([768])
    encoder.layer.5.intermediate.dense.weight torch.Size([3072, 768])
    encoder.layer.5.intermediate.dense.bias torch.Size([3072])
    encoder.layer.5.output.dense.weight torch.Size([768, 3072])
    encoder.layer.5.output.dense.bias torch.Size([768])
    encoder.layer.5.output.LayerNorm.weight torch.Size([768])
    encoder.layer.5.output.LayerNorm.bias torch.Size([768])
    encoder.layer.6.attention.self.query.weight torch.Size([768, 768])
    encoder.layer.6.attention.self.query.bias torch.Size([768])
    encoder.layer.6.attention.self.key.weight torch.Size([768, 768])
    encoder.layer.6.attention.self.key.bias torch.Size([768])
    encoder.layer.6.attention.self.value.weight torch.Size([768, 768])
    encoder.layer.6.attention.self.value.bias torch.Size([768])
    encoder.layer.6.attention.output.dense.weight torch.Size([768, 768])
    encoder.layer.6.attention.output.dense.bias torch.Size([768])
    encoder.layer.6.attention.output.LayerNorm.weight torch.Size([768])
    encoder.layer.6.attention.output.LayerNorm.bias torch.Size([768])
    encoder.layer.6.intermediate.dense.weight torch.Size([3072, 768])
    encoder.layer.6.intermediate.dense.bias torch.Size([3072])
    encoder.layer.6.output.dense.weight torch.Size([768, 3072])
    encoder.layer.6.output.dense.bias torch.Size([768])
    encoder.layer.6.output.LayerNorm.weight torch.Size([768])
    encoder.layer.6.output.LayerNorm.bias torch.Size([768])
    encoder.layer.7.attention.self.query.weight torch.Size([768, 768])
    encoder.layer.7.attention.self.query.bias torch.Size([768])
    encoder.layer.7.attention.self.key.weight torch.Size([768, 768])
    encoder.layer.7.attention.self.key.bias torch.Size([768])
    encoder.layer.7.attention.self.value.weight torch.Size([768, 768])
    encoder.layer.7.attention.self.value.bias torch.Size([768])
    encoder.layer.7.attention.output.dense.weight torch.Size([768, 768])
    encoder.layer.7.attention.output.dense.bias torch.Size([768])
    encoder.layer.7.attention.output.LayerNorm.weight torch.Size([768])
    encoder.layer.7.attention.output.LayerNorm.bias torch.Size([768])
    encoder.layer.7.intermediate.dense.weight torch.Size([3072, 768])
    encoder.layer.7.intermediate.dense.bias torch.Size([3072])
    encoder.layer.7.output.dense.weight torch.Size([768, 3072])
    encoder.layer.7.output.dense.bias torch.Size([768])
    encoder.layer.7.output.LayerNorm.weight torch.Size([768])
    encoder.layer.7.output.LayerNorm.bias torch.Size([768])
    encoder.layer.8.attention.self.query.weight torch.Size([768, 768])
    encoder.layer.8.attention.self.query.bias torch.Size([768])
    encoder.layer.8.attention.self.key.weight torch.Size([768, 768])
    encoder.layer.8.attention.self.key.bias torch.Size([768])
    encoder.layer.8.attention.self.value.weight torch.Size([768, 768])
    encoder.layer.8.attention.self.value.bias torch.Size([768])
    encoder.layer.8.attention.output.dense.weight torch.Size([768, 768])
    encoder.layer.8.attention.output.dense.bias torch.Size([768])
    encoder.layer.8.attention.output.LayerNorm.weight torch.Size([768])
    encoder.layer.8.attention.output.LayerNorm.bias torch.Size([768])
    encoder.layer.8.intermediate.dense.weight torch.Size([3072, 768])
    encoder.layer.8.intermediate.dense.bias torch.Size([3072])
    encoder.layer.8.output.dense.weight torch.Size([768, 3072])
    encoder.layer.8.output.dense.bias torch.Size([768])
    encoder.layer.8.output.LayerNorm.weight torch.Size([768])
    encoder.layer.8.output.LayerNorm.bias torch.Size([768])
    encoder.layer.9.attention.self.query.weight torch.Size([768, 768])
    encoder.layer.9.attention.self.query.bias torch.Size([768])
    encoder.layer.9.attention.self.key.weight torch.Size([768, 768])
    encoder.layer.9.attention.self.key.bias torch.Size([768])
    encoder.layer.9.attention.self.value.weight torch.Size([768, 768])
    encoder.layer.9.attention.self.value.bias torch.Size([768])
    encoder.layer.9.attention.output.dense.weight torch.Size([768, 768])
    encoder.layer.9.attention.output.dense.bias torch.Size([768])
    encoder.layer.9.attention.output.LayerNorm.weight torch.Size([768])
    encoder.layer.9.attention.output.LayerNorm.bias torch.Size([768])
    encoder.layer.9.intermediate.dense.weight torch.Size([3072, 768])
    encoder.layer.9.intermediate.dense.bias torch.Size([3072])
    encoder.layer.9.output.dense.weight torch.Size([768, 3072])
    encoder.layer.9.output.dense.bias torch.Size([768])
    encoder.layer.9.output.LayerNorm.weight torch.Size([768])
    encoder.layer.9.output.LayerNorm.bias torch.Size([768])
    encoder.layer.10.attention.self.query.weight torch.Size([768, 768])
    encoder.layer.10.attention.self.query.bias torch.Size([768])
    encoder.layer.10.attention.self.key.weight torch.Size([768, 768])
    encoder.layer.10.attention.self.key.bias torch.Size([768])
    encoder.layer.10.attention.self.value.weight torch.Size([768, 768])
    encoder.layer.10.attention.self.value.bias torch.Size([768])
    encoder.layer.10.attention.output.dense.weight torch.Size([768, 768])
    encoder.layer.10.attention.output.dense.bias torch.Size([768])
    encoder.layer.10.attention.output.LayerNorm.weight torch.Size([768])
    encoder.layer.10.attention.output.LayerNorm.bias torch.Size([768])
    encoder.layer.10.intermediate.dense.weight torch.Size([3072, 768])
    encoder.layer.10.intermediate.dense.bias torch.Size([3072])
    encoder.layer.10.output.dense.weight torch.Size([768, 3072])
    encoder.layer.10.output.dense.bias torch.Size([768])
    encoder.layer.10.output.LayerNorm.weight torch.Size([768])
    encoder.layer.10.output.LayerNorm.bias torch.Size([768])
    encoder.layer.11.attention.self.query.weight torch.Size([768, 768])
    encoder.layer.11.attention.self.query.bias torch.Size([768])
    encoder.layer.11.attention.self.key.weight torch.Size([768, 768])
    encoder.layer.11.attention.self.key.bias torch.Size([768])
    encoder.layer.11.attention.self.value.weight torch.Size([768, 768])
    encoder.layer.11.attention.self.value.bias torch.Size([768])
    encoder.layer.11.attention.output.dense.weight torch.Size([768, 768])
    encoder.layer.11.attention.output.dense.bias torch.Size([768])
    encoder.layer.11.attention.output.LayerNorm.weight torch.Size([768])
    encoder.layer.11.attention.output.LayerNorm.bias torch.Size([768])
    encoder.layer.11.intermediate.dense.weight torch.Size([3072, 768])
    encoder.layer.11.intermediate.dense.bias torch.Size([3072])
    encoder.layer.11.output.dense.weight torch.Size([768, 3072])
    encoder.layer.11.output.dense.bias torch.Size([768])
    encoder.layer.11.output.LayerNorm.weight torch.Size([768])
    encoder.layer.11.output.LayerNorm.bias torch.Size([768])
    pooler.dense.weight torch.Size([768, 768])
    pooler.dense.bias torch.Size([768])
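    To see what each kind of pruning saves, the parameter count can be recomputed from the bert-base-chinese shapes printed above. A small sketch; the function name `bert_param_count` is made up for illustration, and it assumes the standard BertModel layout with a pooler, exactly as in the listing:

```python
def bert_param_count(num_layers=12, intermediate_size=3072, hidden=768,
                     vocab_size=21128, max_positions=512, type_vocab_size=2):
    """Count BertModel parameters from the tensor shapes listed above."""
    # word + position + token_type embeddings, plus embeddings.LayerNorm (weight, bias)
    embeddings = (vocab_size + max_positions + type_vocab_size) * hidden + 2 * hidden
    # query, key, value and attention.output.dense (weight + bias each), plus LayerNorm
    attention = 4 * (hidden * hidden + hidden) + 2 * hidden
    # intermediate.dense, output.dense and output.LayerNorm
    ffn = (intermediate_size * hidden + intermediate_size) \
        + (hidden * intermediate_size + hidden) + 2 * hidden
    pooler = hidden * hidden + hidden
    return embeddings + num_layers * (attention + ffn) + pooler


print(bert_param_count())                                     # 102267648 (12 layers)
print(bert_param_count(num_layers=2))                         # 31388928 (layers 0 and 11 kept)
print(bert_param_count(num_layers=2, intermediate_size=384))  # 23126016 (plus FFN pruned to 384)
```

    The embeddings (16,622,592 parameters, almost all of them in word_embeddings) are untouched by both layer and FFN pruning, which is why the 2-layer model still keeps about 31% of the parameters, and about 23% after also shrinking intermediate_size to 384.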
    

    Full code:

    import os
    import json
    import shutil
    import time
    import torch
    from transformers import BertModel, BertTokenizer
    
    # Extract the weights of the layers we want to keep and rename them
    def get_prune_parameters(model):
        prune_parameters = {}
        for name, param in model.named_parameters():
            if 'embeddings' in name:
                prune_parameters[name] = param.detach()
            elif name.startswith('encoder.layer.0.'):
                prune_parameters[name] = param.detach()
            elif name.startswith('encoder.layer.11.'):
                # layer 11 becomes layer 1 of the pruned model
                suffix = name[len('encoder.layer.11.'):]
                prune_parameters['encoder.layer.1.' + suffix] = param.detach()
            elif 'pooler' in name:
                prune_parameters[name] = param.detach()
        return prune_parameters
    
    # Modify the config: num_hidden_layers = number of layers kept
    def get_prune_config(config):
        prune_config = dict(config)
        prune_config['num_hidden_layers'] = 2
        return prune_config
    
    # Shrink the state dict to 2 layers and assign the kept weights
    def get_prune_model(model, prune_parameters):
        prune_model = model.state_dict()
        for name in list(prune_model.keys()):
            if name in prune_parameters:
                # embeddings, pooler, layer 0, and layer 1 (now holding layer 11's weights)
                prune_model[name] = prune_parameters[name]
            elif name.startswith('encoder.layer.'):
                # layers 2..11 do not exist in the 2-layer model
                del prune_model[name]
            # other entries (e.g. the embeddings.position_ids buffer) are kept unchanged
        return prune_model
    
    def prune_main():
        model_path = '/data02/gob/project/simpleNLP/model_hub/chinese-bert-wwm-ext/'
        with open(model_path + 'config.json', 'r', encoding='utf-8') as fp:
            config = json.load(fp)
        model = BertModel.from_pretrained(model_path)
    
        out_path = '/data02/gob/project/simpleNLP/model_hub/prune-chinese-bert-wwm-ext/'
        os.makedirs(out_path, exist_ok=True)
        prune_parameters = get_prune_parameters(model)
        prune_config = get_prune_config(config)
        prune_model = get_prune_model(model, prune_parameters)
        torch.save(prune_model, out_path + 'pytorch_model.bin')
        with open(out_path + 'config.json', 'w', encoding='utf-8') as fp:
            json.dump(prune_config, fp, indent=2)
        # copy the vocabulary so the tokenizer can be loaded from out_path
        shutil.copy(model_path + 'vocab.txt', out_path + 'vocab.txt')
    
    if __name__ == '__main__':
        # Run prune_main() once to write the pruned checkpoint
        # prune_main()
        # Afterwards the pruned model loads just like any BERT checkpoint
        model_path = '/data02/gob/project/simpleNLP/model_hub/prune-chinese-bert-wwm-ext/'
        tokenizer = BertTokenizer.from_pretrained(model_path)
        model = BertModel.from_pretrained(model_path)
        model.eval()
        for name, param in model.named_parameters():
            print(name, param.shape)
        text = '我喜欢吃鱼'  # "I like eating fish"
        inputs = tokenizer(text, return_tensors='pt')
        start_time = time.time()
        with torch.no_grad():
            outputs = model(**inputs)
        end_time = time.time()
        print('Inference time: {}s'.format(end_time - start_time))
    

    Pruning the FFN (intermediate) dimension

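    1. Load the pretrained model;
    2. For each layer, choose the k FFN neurons to keep (e.g. the top-k by an importance score) and slice intermediate.dense.weight, intermediate.dense.bias and output.dense.weight with the same indices, then assign the results back to that layer's parameters;
    3. Modify the config file (mainly the dimension, i.e. intermediate_size);
    4. Save the model as pytorch_model.bin.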
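    Why the same indices matter: FFN neuron j owns row j of intermediate.dense.weight, entry j of intermediate.dense.bias and column j of output.dense.weight, so removing a neuron means removing all three together. Taking topk of each tensor independently would mix up rows and columns from different neurons. The toy sketch below (the sizes and the magnitude-based score are illustrative assumptions) checks that slicing with one shared index set gives the same output as the full FFN with the dropped neurons zeroed out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden, inter, k = 8, 16, 4         # toy sizes (BERT-base: 768, 3072)
fc_in = nn.Linear(hidden, inter)    # plays intermediate.dense, weight [inter, hidden]
fc_out = nn.Linear(inter, hidden)   # plays output.dense, weight [hidden, inter]

with torch.no_grad():
    # pick k neurons once, here by incoming-norm * outgoing-norm ...
    scores = fc_in.weight.norm(dim=1) * fc_out.weight.norm(dim=0)
    idx = scores.topk(k).indices.sort().values
    # ... and slice all three tensors with the SAME indices
    w_in_p = fc_in.weight[idx, :]   # [k, hidden]
    b_in_p = fc_in.bias[idx]        # [k]
    w_out_p = fc_out.weight[:, idx] # [hidden, k]

    x = torch.randn(3, hidden)
    pruned_out = F.linear(F.gelu(F.linear(x, w_in_p, b_in_p)), w_out_p, fc_out.bias)

    # reference: full FFN with the activations of dropped neurons set to zero
    mask = torch.zeros(inter)
    mask[idx] = 1.0
    masked_out = fc_out(F.gelu(fc_in(x)) * mask)

print(tuple(w_in_p.shape), tuple(b_in_p.shape), tuple(w_out_p.shape))  # (4, 8) (4,) (8, 4)
print(torch.allclose(pruned_out, masked_out, atol=1e-6))              # True
```

    output.dense.bias is not indexed by FFN neurons, so it stays at its original size. Sorting the kept indices is optional; it just preserves the original neuron order.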
    Code:

    import os
    import json
    import shutil
    import time
    import torch
    from transformers import BertModel, BertTokenizer
    
    PRUNED_SIZE = 384  # new intermediate_size (BERT-base uses 3072)
    
    
    # For every layer, keep the same k FFN neurons in all three tensors that
    # belong to them: rows of intermediate.dense.weight / .bias and columns of
    # output.dense.weight (output.dense.bias is unaffected).
    def get_prune_ffn_parameters(model, k=PRUNED_SIZE):
        prune_parameters = {}
        for i, layer in enumerate(model.encoder.layer):
            w_in = layer.intermediate.dense.weight.detach()  # [intermediate, hidden]
            b_in = layer.intermediate.dense.bias.detach()    # [intermediate]
            w_out = layer.output.dense.weight.detach()       # [hidden, intermediate]
            # importance score per neuron: norm of incoming weights times norm of
            # outgoing weights (a simple magnitude heuristic, not the only choice)
            scores = w_in.norm(dim=1) * w_out.norm(dim=0)
            idx = scores.topk(k).indices.sort().values     # keep original order
            prefix = 'encoder.layer.{}.'.format(i)
            prune_parameters[prefix + 'intermediate.dense.weight'] = w_in[idx, :]
            prune_parameters[prefix + 'intermediate.dense.bias'] = b_in[idx]
            prune_parameters[prefix + 'output.dense.weight'] = w_out[:, idx]
        return prune_parameters
    
    
    def get_prune_ffn_config(config, k=PRUNED_SIZE):
        prune_config = dict(config)
        prune_config['intermediate_size'] = k
        return prune_config
    
    
    # Replace the FFN tensors in the state dict with their pruned versions
    def get_prune_model(model, prune_parameters):
        prune_model = model.state_dict()
        for name in list(prune_model.keys()):
            if name in prune_parameters:
                prune_model[name] = prune_parameters[name]
        return prune_model
    
    
    def prune_main():
        # start from the layer-pruned model produced in the previous section
        model_path = '/data02/gob/project/simpleNLP/model_hub/prune-chinese-bert-wwm-ext/'
        with open(model_path + 'config.json', 'r', encoding='utf-8') as fp:
            config = json.load(fp)
        model = BertModel.from_pretrained(model_path)
    
        out_path = '/data02/gob/project/simpleNLP/model_hub/prune-ffn-chinese-bert-wwm-ext/'
        os.makedirs(out_path, exist_ok=True)
        prune_parameters = get_prune_ffn_parameters(model)
        prune_config = get_prune_ffn_config(config)
        prune_model = get_prune_model(model, prune_parameters)
        torch.save(prune_model, out_path + 'pytorch_model.bin')
        with open(out_path + 'config.json', 'w', encoding='utf-8') as fp:
            json.dump(prune_config, fp, indent=2)
        shutil.copy(model_path + 'vocab.txt', out_path + 'vocab.txt')
    
    
    if __name__ == '__main__':
        # Run prune_main() once to write the pruned checkpoint
        # prune_main()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model_path = '/data02/gob/project/simpleNLP/model_hub/prune-ffn-chinese-bert-wwm-ext/'
        # other checkpoints to compare against:
        # model_path = '/data02/gob/project/simpleNLP/model_hub/prune-chinese-bert-wwm-ext/'
        # model_path = '/data02/gob/project/simpleNLP/model_hub/bert-base-chinese/'
        tokenizer = BertTokenizer.from_pretrained(model_path)
        model = BertModel.from_pretrained(model_path)
        model.to(device)
        model.eval()
        texts = ['我喜欢吃鱼,我也喜欢打篮球,你知不知道呀。在这个阳光明媚的日子里,我们一起去放风筝'] * 5000
        start_time = time.time()
        with torch.no_grad():
            for text in texts:
                inputs = tokenizer(text, return_tensors='pt').to(device)
                model(**inputs)  # the forward pass is what we want to time
        if device.type == 'cuda':
            torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
        end_time = time.time()
        print('Inference time: {}s'.format(end_time - start_time))
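    Wall-clock timings of a model are easy to distort: the first calls include one-off costs, autograd bookkeeping adds overhead, and on a GPU the clock can stop before queued kernels have finished. A minimal timing helper under those assumptions is sketched below; `benchmark` and the toy module are hypothetical stand-ins, not part of transformers.

```python
import time
import torch
import torch.nn as nn


def benchmark(model, inputs, n_iters=100, n_warmup=10):
    """Average wall-clock seconds per forward pass of model(**inputs)."""
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():                  # no autograd graph during inference
        for _ in range(n_warmup):          # warm-up: allocator, kernel selection, caches
            model(**inputs)
        if device.type == 'cuda':
            torch.cuda.synchronize()       # CUDA launches are asynchronous
        start = time.perf_counter()
        for _ in range(n_iters):
            model(**inputs)
        if device.type == 'cuda':
            torch.cuda.synchronize()
        end = time.perf_counter()
    return (end - start) / n_iters


class Toy(nn.Module):
    """Hypothetical stand-in for BertModel that takes keyword inputs."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(768, 768)

    def forward(self, x):
        return self.fc(x)


seconds = benchmark(Toy(), {'x': torch.randn(1, 32, 768)}, n_iters=20)
print('{:.6f}s per forward pass'.format(seconds))
```

    With the BERT models above, the same helper would be called as `benchmark(model, tokenizer(text, return_tensors='pt').to(device))`, which keeps tokenization out of the measured time.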
    

    Pruning attention heads and the hidden dimension

    These are comparatively involved, so they are not covered here for now. In most cases, pruning layers is simple and convenient.

  • Original article: https://www.cnblogs.com/xiximayou/p/15193655.html