BERT (PyTorch): tracing how cls.predictions.bias goes from all zeros to its loaded pretrained values

    A simple entry point looks like this:

    import sys
    sys.path.append('..')
    
    import torch
    from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
    
    # Load pre-trained model tokenizer (vocabulary)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    
    # Tokenized input
    text = "Who was Jim Henson ? Jim Henson was a puppeteer"
    tokenized_text = tokenizer.tokenize(text)
    
    # Mask a token that we will try to predict back with `BertForMaskedLM`
    masked_index = 6
    tokenized_text[masked_index] = '[MASK]'
    assert tokenized_text == ['who', 'was', 'jim', 'henson', '?', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer']
    
    # Convert token to vocabulary indices
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    # Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
    segments_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    # segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    
    # Convert inputs to PyTorch tensors
    tokens_tensor = torch.tensor([indexed_tokens]).to('cuda')
    segments_tensors = torch.tensor([segments_ids]).to('cuda')
    
    # ========================= BertForMaskedLM ==============================
    # Load pre-trained model (weights)
    model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    model.to('cuda')
    model.eval()

    The entry point for this trace is the third line from the bottom: the call to BertForMaskedLM.from_pretrained.
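    For completeness, here is a minimal continuation of the script (a sketch, not part of the original post; in this old pytorch_pretrained_bert API, calling the model without masked_lm_labels returns the raw vocabulary scores):

    # Predict the token at the masked position (continuing the script above)
    with torch.no_grad():
        predictions = model(tokens_tensor, segments_tensors)  # torch.Size([1, 11, 30522])

    predicted_index = torch.argmax(predictions[0, masked_index]).item()
    predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
    print(predicted_token)  # should recover 'henson'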

    Stepping into the from_pretrained method, the logic proceeds in a clear sequence:

        @classmethod
        def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
            """
            Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
            Download and cache the pre-trained model file if needed.
    
            Params:
                pretrained_model_name: either:
                    - a str with the name of a pre-trained model to load selected in the list of:
                        . `bert-base-uncased`
                        . `bert-large-uncased`
                        . `bert-base-cased`
                        . `bert-large-cased`
                        . `bert-base-multilingual-uncased`
                        . `bert-base-multilingual-cased`
                        . `bert-base-chinese`
                    - a path or url to a pretrained model archive containing:
                        . `bert_config.json` a configuration file for the model
                        . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
                cache_dir: an optional path to a folder in which the pre-trained models will be cached.
                state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
                *inputs, **kwargs: additional input for the specific Bert class
                    (ex: num_labels for BertForSequenceClassification)
            """
            if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
                archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
            else:
                archive_file = pretrained_model_name
            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
            except FileNotFoundError:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name,
                        ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                        archive_file))
                return None
            if resolved_archive_file == archive_file:
                logger.info("loading archive file {}".format(archive_file))
            else:
                logger.info("loading archive file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
            tempdir = None
            if os.path.isdir(resolved_archive_file):
                serialization_dir = resolved_archive_file
            else:
                # Extract archive to temp dir
                tempdir = tempfile.mkdtemp()
                logger.info("extracting archive file {} to temp dir {}".format(
                    resolved_archive_file, tempdir))
                with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                    archive.extractall(tempdir)
                serialization_dir = tempdir
            # Load config
            config_file = os.path.join(serialization_dir, CONFIG_NAME)
            config = BertConfig.from_json_file(config_file)
            logger.info("Model config {}".format(config))
            # Instantiate model.
            model = cls(config, *inputs, **kwargs)
            if state_dict is None:
                weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
                state_dict = torch.load(weights_path)
    
            old_keys = []
            new_keys = []
            for key in state_dict.keys():
                new_key = None
                if 'gamma' in key:
                    new_key = key.replace('gamma', 'weight')
                if 'beta' in key:
                    new_key = key.replace('beta', 'bias')
                if new_key:
                    old_keys.append(key)
                    new_keys.append(new_key)
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)
    
            missing_keys = []
            unexpected_keys = []
            error_msgs = []
            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, '_metadata', None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata
    
            def load(module, prefix=''):
                local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
                module._load_from_state_dict(
                    state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + '.')
            load(model, prefix='' if hasattr(model, 'bert') else 'bert.')  # NOTE: this is where model.cls.predictions.bias has its all-zero init overwritten by the pretrained values
            if len(missing_keys) > 0:
                logger.info("Weights of {} not initialized from pretrained model: {}".format(
                    model.__class__.__name__, missing_keys))
            if len(unexpected_keys) > 0:
                logger.info("Weights from pretrained model not used in {}: {}".format(
                    model.__class__.__name__, unexpected_keys))
            if tempdir:
                # Clean up temp dir
                shutil.rmtree(tempdir)
            return model

    The method is long-ish, but all it really does is instantiate the model and then load every pretrained parameter into it.
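    One step worth calling out is the gamma/beta renaming loop, presumably kept for compatibility with checkpoints that used the older gamma/beta names for LayerNorm parameters. A tiny illustration of what it does (a sketch with a made-up key, not taken from the actual checkpoint):

    import torch

    # made-up key in the old naming scheme
    state_dict = {'bert.encoder.layer.0.attention.output.LayerNorm.gamma': torch.ones(768)}

    old_keys, new_keys = [], []
    for key in state_dict:
        new_key = key.replace('gamma', 'weight').replace('beta', 'bias')
        if new_key != key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

    print(list(state_dict))  # ['bert.encoder.layer.0.attention.output.LayerNorm.weight']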

    Now pay attention to the inner load function:

            def load(module, prefix=''):
                local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
                module._load_from_state_dict(
                    state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + '.')
            load(model, prefix='' if hasattr(model, 'bert') else 'bert.')  # NOTE: this is where model.cls.predictions.bias has its all-zero init overwritten by the pretrained values
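    The recursion works purely on parameter names: load(module, prefix) runs _load_from_state_dict for the current module, then descends into every child with the child's name appended to the prefix. A toy illustration of how those dotted names arise (a sketch, nothing BERT-specific):

    import torch.nn as nn

    toy = nn.Sequential(nn.Linear(2, 3), nn.Linear(3, 1))
    for name, _ in toy.named_parameters():
        print(name)  # 0.weight / 0.bias / 1.weight / 1.bias

    The prefix='' if hasattr(model, 'bert') else 'bert.' bit handles a naming mismatch: BertForMaskedLM already owns a bert submodule, so its parameter names line up with the checkpoint as-is, whereas a bare BertModel needs 'bert.' prepended to match keys saved from a BertForPreTraining instance. Either way, a key such as cls.predictions.bias lands exactly on model.cls.predictions.bias.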

    This load call pulls in every pretrained parameter. So which bias exactly does cls.predictions.bias refer to? It belongs to this class:

    class BertLMPredictionHead(nn.Module):
        """
        Arch:
            - BertPredictionHeadTransform (Input=torch.Size([1, 11, 768]), Output=torch.Size([1, 11, 768]))
                - Dense (768, 768)
                - Activation (gelu)
                - LayerNorm
            - Linear (768, 30522)
    
        y = x @ W^T + b
        y = self.decoder(x) + self.bias
        i.e., [1, 11, 768] @ [768, 30522] + [30522] -> [1, 11, 30522]
    
        Input:
            torch.Size([1, 11, 768])
        Output:
            torch.Size([1, 11, 30522])
    
        Its purpose is to decode hidden states into vocabulary logits.
        """
        def __init__(self, config, bert_model_embedding_weights):
            super(BertLMPredictionHead, self).__init__()
            self.transform = BertPredictionHeadTransform(config)
    
            """
            bert_model_embedding_weights.size():
                torch.Size([30522, 768])
            """
            # The output weights are the same as the input embeddings, but there is
            # an output-only bias for each token.
            self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                     bert_model_embedding_weights.size(0),
                                 bias=False)  # Linear(768 -> 30522); the weight it holds has shape [30522, 768]
            self.decoder.weight = bert_model_embedding_weights  # torch.Size([30522, 768])
            self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))  # torch.Size([30522])
    
        def forward(self, hidden_states):
            """
            hidden_states:
                torch.Size([1, 11, 768])
    
            torch.Size([1, 11, 768]) --> torch.Size([1, 11, 768])
            """
            hidden_states = self.transform(hidden_states)
            """
            Predict the corresponding vocabulary word at each position.

            Each of the 11 positions gets a tensor of size [30522], matching the vocabulary size.
            """
            hidden_states = self.decoder(hidden_states) + self.bias  # torch.Size([1, 11, 30522])
            return hidden_states
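    For intuition, the decoder step is just a matrix product against the (tied) embedding matrix plus the standalone output bias; a shape-level equivalent (a sketch with random tensors):

    import torch

    hidden = torch.randn(1, 11, 768)  # output of the transform
    W = torch.randn(30522, 768)       # stands in for the tied embedding weight
    b = torch.zeros(30522)            # the bias this post is chasing

    logits = hidden @ W.t() + b
    print(logits.shape)  # torch.Size([1, 11, 30522])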

    It is this bias:

            self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))  # torch.Size([30522])

    But here is what puzzled me: this class is not part of the BERT model proper; it is an add-on head used to predict the ids of the [MASK] tokens. So I went looking in the big pretrained state_dict, and found these entries:

    'cls.predictions.bias' = {Tensor: 30522} tensor([-0.4191, -0.4202, -0.4191,  ..., -0.7900, -0.7822, -0.4965])
    'cls.predictions.transform.dense.weight' = {Tensor: 768} tensor([[ 0.3681,  0.0147,  0.0430,  ...,  0.0384, -0.0296,  0.0227],
            [ 0.0034,  0.2647, -0.0618,  ..., -0.0397, -0.0335,  0.0203],
            [ 0.0179, -0.0060,  0.1788,  ...,  0.0267,  0.0555, -0.0432],
            ...,
            [ 0.0784,  0.0172,  0.0583,  ...,  0.3548,  0.0209, -0.0261],
            [ 0.0175, -0.0466,  0.0834,  ...,  0.0069,  0.2132, -0.0503],
            [-0.0832,  0.0461,  0.0490,  ..., -0.0116, -0.0594,  0.3525]])
    'cls.predictions.transform.dense.bias' = {Tensor: 768} tensor([ 5.3890e-02,  1.0068e-01,  4.5532e-02,  2.7030e-02,  3.8845e-02,
             3.3157e-02,  4.1188e-02,  2.8206e-02,  2.4197e-02,  1.3879e-01,
             4.4386e-02,  4.8806e-02,  3.4415e-02,  5.9976e-02,  4.2772e-02,
             2.5261e-02,  1.0533e-01,  4.1858e-02,  4.9016e-02,  9.8930e-02,
             2.4026e-02,  4.1394e-02,  4.2273e-02,  2.9724e-02,  1.0857e-01,
             4.8379e-02,  3.6337e-02,  5.2781e-02,  2.9902e-02,  2.6919e-02,
             2.1127e-02,  4.8463e-02,  5.7389e-02,  4.8581e-02,  9.8151e-02,
             6.3899e-02,  4.4544e-02,  4.9595e-02,  4.5315e-02,  3.5128e-02,
             3.4962e-02,  6.9260e-02,  4.8273e-02,  4.3921e-02,  3.6126e-02,
             3.9017e-02,  4.7681e-02,  4.1840e-02,  4.2173e-02,  5.2243e-02,
             3.3530e-02,  4.3681e-02,  9.2896e-02, -1.3240e-01,  3.5652e-02,
             3.2232e-02,  6.1398e-02,  3.9744e-02,  4.3546e-02,  3.7697e-02,
             3.2834e-02,  2.5923e-02, -7.8080e-02,  2.7405e-02,  7.5468e-02,
             3.8439e-02,  8.4586e-02,  3.0094e-02,  3.6...
    'cls.predictions.decoder.weight' = {Tensor: 30522} tensor([[-0.0102, -0.0615, -0.0265,  ..., -0.0199, -0.0372, -0.0098],
            [-0.0117, -0.0600, -0.0323,  ..., -0.0168, -0.0401, -0.0107],
            [-0.0198, -0.0627, -0.0326,  ..., -0.0165, -0.0420, -0.0032],
            ...,
            [-0.0218, -0.0556, -0.0135,  ..., -0.0043, -0.0151, -0.0249],
            [-0.0462, -0.0565, -0.0019,  ...,  0.0157, -0.0139, -0.0095],
            [ 0.0015, -0.0821, -0.0160,  ..., -0.0081, -0.0475,  0.0753]])
    'cls.seq_relationship.weight' = {Tensor: 2} tensor([[-0.0154, -0.0062, -0.0137,  ..., -0.0128, -0.0099,  0.0006],
            [ 0.0058,  0.0120,  0.0128,  ...,  0.0088,  0.0137, -0.0162]])
    'cls.seq_relationship.bias' = {Tensor: 2} tensor([ 0.0211, -0.0021])

    There are well over a hundred differently named weights in total, and a handful of them are named with the cls. prefix.
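    This is easy to reproduce (a sketch; it assumes the extracted pytorch_model.bin is sitting in the working directory):

    import torch

    sd = torch.load('pytorch_model.bin', map_location='cpu')
    print(len(sd))                                  # well over a hundred entries
    print([k for k in sd if k.startswith('cls.')])  # the head weights shown above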

    Looking at the loading logic, parameters are matched purely by name, so the model's cls.predictions.bias, initialized to all zeros, simply gets overwritten.
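    This is easy to verify on the loaded model (a sketch, continuing the entry script):

    bias = model.cls.predictions.bias
    print(bias[:3])                 # e.g. tensor([-0.4191, -0.4202, -0.4191], ...)
    print(bool((bias == 0).all()))  # False: no longer the all-zero init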

    At first this seemed strange, because I did not expect the dict to contain such entries. On reflection, though, the [MASK]-prediction head is exercised during pretraining as well, so its weights were saved along with everything else.

    Likewise, cls.predictions.decoder.weight also appears to get overwritten, which makes the model's up-front trick of initializing that weight from the Embedding layer's weight look unnecessary. From the code, the weight handed over directly from BERT looks like this:

    Parameter containing:
    tensor([[-0.0102, -0.0615, -0.0265,  ..., -0.0199, -0.0372, -0.0098],
            [-0.0117, -0.0600, -0.0323,  ..., -0.0168, -0.0401, -0.0107],
            [-0.0198, -0.0627, -0.0326,  ..., -0.0165, -0.0420, -0.0032],
            ...,
            [-0.0218, -0.0556, -0.0135,  ..., -0.0043, -0.0151, -0.0249],
            [-0.0462, -0.0565, -0.0019,  ...,  0.0157, -0.0139, -0.0095],
            [ 0.0015, -0.0821, -0.0160,  ..., -0.0081, -0.0475,  0.0753]],
           requires_grad=True)
    -0.0102, -0.0615 ... these leading numbers match those of the fourth entry above, cls.predictions.decoder.weight, so we can reasonably assert that the two weights are identical.
    In other words, it is the Embedding layer's weight.
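    This can be confirmed directly: after loading, the decoder weight and the word-embedding weight are one and the same tensor (a sketch, continuing the entry script):

    emb = model.bert.embeddings.word_embeddings.weight
    dec = model.cls.predictions.decoder.weight

    print(dec.data_ptr() == emb.data_ptr())  # True: they share the same storage
    print(torch.equal(dec, emb))             # True: identical values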
    As for a conclusion... well, this pretrained checkpoint could stand to be slimmed down a little (tongue firmly in cheek).
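    As a back-of-envelope estimate (a sketch, assuming fp32 storage and that the tied matrix really is serialized twice):

    redundant_bytes = 30522 * 768 * 4  # one extra copy of the tied [30522, 768] matrix
    print(redundant_bytes / 2**20)     # ~89.4 MiB that could be saved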
Original post: https://www.cnblogs.com/DDBD/p/14470519.html