  • BERT (PyTorch version): tracing how cls.predictions.bias goes from all zeros to its loaded pretrained values

    A simple entry script looks like this:

    import sys
    sys.path.append('..')
    
    import torch
    from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
    
    # Load pre-trained model tokenizer (vocabulary)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    
    # Tokenized input
    text = "Who was Jim Henson ? Jim Henson was a puppeteer"
    tokenized_text = tokenizer.tokenize(text)
    
    # Mask a token that we will try to predict back with `BertForMaskedLM`
    masked_index = 6
    tokenized_text[masked_index] = '[MASK]'
    assert tokenized_text == ['who', 'was', 'jim', 'henson', '?', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer']
    
    # Convert token to vocabulary indices
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    # Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
    segments_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    # segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    
    # Convert inputs to PyTorch tensors
    tokens_tensor = torch.tensor([indexed_tokens]).to('cuda')
    segments_tensors = torch.tensor([segments_ids]).to('cuda')
    
    # ========================= BertForMaskedLM ==============================
    # Load pre-trained model (weights)
    model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    model.to('cuda')
    model.eval()

    The line to follow is the third-to-last one: model = BertForMaskedLM.from_pretrained('bert-base-uncased').

    From there we step into the from_pretrained method, whose logic runs in a fairly clear order:

        @classmethod
        def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
            """
            Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
            Download and cache the pre-trained model file if needed.
    
            Params:
                pretrained_model_name: either:
                    - a str with the name of a pre-trained model to load selected in the list of:
                        . `bert-base-uncased`
                        . `bert-large-uncased`
                        . `bert-base-cased`
                        . `bert-large-cased`
                        . `bert-base-multilingual-uncased`
                        . `bert-base-multilingual-cased`
                        . `bert-base-chinese`
                    - a path or url to a pretrained model archive containing:
                        . `bert_config.json` a configuration file for the model
                        . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
                cache_dir: an optional path to a folder in which the pre-trained models will be cached.
                state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
                *inputs, **kwargs: additional input for the specific Bert class
                    (ex: num_labels for BertForSequenceClassification)
            """
            if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
                archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
            else:
                archive_file = pretrained_model_name
            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
            except FileNotFoundError:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name,
                        ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                        archive_file))
                return None
            if resolved_archive_file == archive_file:
                logger.info("loading archive file {}".format(archive_file))
            else:
                logger.info("loading archive file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
            tempdir = None
            if os.path.isdir(resolved_archive_file):
                serialization_dir = resolved_archive_file
            else:
                # Extract archive to temp dir
                tempdir = tempfile.mkdtemp()
                logger.info("extracting archive file {} to temp dir {}".format(
                    resolved_archive_file, tempdir))
                with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                    archive.extractall(tempdir)
                serialization_dir = tempdir
            # Load config
            config_file = os.path.join(serialization_dir, CONFIG_NAME)
            config = BertConfig.from_json_file(config_file)
            logger.info("Model config {}".format(config))
            # Instantiate model.
            model = cls(config, *inputs, **kwargs)
            if state_dict is None:
                weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
                state_dict = torch.load(weights_path)
    
            old_keys = []
            new_keys = []
            for key in state_dict.keys():
                new_key = None
                if 'gamma' in key:
                    new_key = key.replace('gamma', 'weight')
                if 'beta' in key:
                    new_key = key.replace('beta', 'bias')
                if new_key:
                    old_keys.append(key)
                    new_keys.append(new_key)
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)
    
            missing_keys = []
            unexpected_keys = []
            error_msgs = []
            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, '_metadata', None)
            state_dict = state_dict.copy()
            if metadata is not None:
                state_dict._metadata = metadata
    
            def load(module, prefix=''):
                local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
                module._load_from_state_dict(
                    state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + '.')
            load(model, prefix='' if hasattr(model, 'bert') else 'bert.')  # todo: this is where model.cls.predictions.bias is overwritten, from all zeros to the pretrained values
            if len(missing_keys) > 0:
                logger.info("Weights of {} not initialized from pretrained model: {}".format(
                    model.__class__.__name__, missing_keys))
            if len(unexpected_keys) > 0:
                logger.info("Weights from pretrained model not used in {}: {}".format(
                    model.__class__.__name__, unexpected_keys))
            if tempdir:
                # Clean up temp dir
                shutil.rmtree(tempdir)
            return model

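    The method is a bit long, but all it does is locate and unpack the archive, read the config, instantiate the model, and then load every pretrained parameter into it.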
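    One step in that method is easy to miss: before loading, any key containing gamma or beta is renamed to use weight or bias instead. The sketch below reproduces that renaming on a plain dict. It is a minimal pure-Python version. The helper name rename_layernorm_keys and the sample keys are hypothetical, and floats stand in for tensors. The original code pops and reinserts keys in place, which moves renamed keys to the end of the dict. This version builds a new dict instead, so key order is preserved but the resulting set of names is the same.

```python
from collections import OrderedDict


def rename_layernorm_keys(state_dict):
    """Hypothetical helper: return a new OrderedDict where 'gamma' -> 'weight'
    and 'beta' -> 'bias', mirroring the renaming loop in from_pretrained."""
    renamed = OrderedDict()
    for key, value in state_dict.items():
        new_key = key
        if 'gamma' in key:
            new_key = key.replace('gamma', 'weight')
        if 'beta' in key:
            # Same as the original: built from `key`, not from `new_key`.
            new_key = key.replace('beta', 'bias')
        renamed[new_key] = value
    return renamed


# Floats stand in for tensors; the key names are illustrative.
old = OrderedDict([
    ('bert.embeddings.LayerNorm.gamma', 1.0),
    ('bert.embeddings.LayerNorm.beta', 0.0),
    ('cls.predictions.bias', -0.42),
])
new = rename_layernorm_keys(old)
print(list(new.keys()))
```

    Keys that contain neither substring, such as cls.predictions.bias, pass through unchanged.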

    Now look at the load function inside it:

            def load(module, prefix=''):
                local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
                module._load_from_state_dict(
                    state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + '.')
            load(model, prefix='' if hasattr(model, 'bert') else 'bert.')  # todo: this is where model.cls.predictions.bias is overwritten, from all zeros to the pretrained values
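    The key idea in load is that each child module is visited with prefix + name + '.'. A parameter's full dotted name, such as cls.predictions.bias, is therefore assembled from the module path and looked up in the checkpoint by that name.

    Below is a simplified pure-Python imitation, written under some assumptions:
    - Module and load_from_state_dict are hypothetical stand-ins, not PyTorch's nn.Module and _load_from_state_dict. The real method copies into tensors in place and also handles buffers and metadata.
    - The unexpected-key rule follows the general shape of PyTorch's strict check. A key under the current prefix is unexpected when its next name component is neither an own parameter nor a child.

```python
class Module:
    """Hypothetical stand-in for nn.Module: own parameters plus named children."""

    def __init__(self, params=None, **children):
        self.params = dict(params or {})
        self.children = children

    def load_from_state_dict(self, state_dict, prefix, missing, unexpected):
        # Copy every own parameter whose full dotted name is in the checkpoint.
        for name in self.params:
            key = prefix + name
            if key in state_dict:
                self.params[name] = state_dict[key]
            else:
                missing.append(key)
        # Simplified strict check: keys under this prefix whose next component
        # is neither a parameter nor a child are reported as unexpected.
        for key in state_dict:
            if key.startswith(prefix):
                first = key[len(prefix):].split('.', 1)[0]
                if first not in self.params and first not in self.children:
                    unexpected.append(key)


def load(module, state_dict, prefix, missing, unexpected):
    """Recursive walk mirroring load() in from_pretrained."""
    module.load_from_state_dict(state_dict, prefix, missing, unexpected)
    for name, child in module.children.items():
        if child is not None:
            load(child, state_dict, prefix + name + '.', missing, unexpected)


# Toy model: cls.predictions.bias starts as zeros, like in BertLMPredictionHead.
model = Module(
    bert=Module(pooler=Module({'bias': 0.0})),
    cls=Module(predictions=Module({'bias': [0.0, 0.0, 0.0]})),
)
checkpoint = {
    'bert.pooler.bias': 0.5,
    'cls.predictions.bias': [-0.4191, -0.4202, -0.4191],
    'cls.seq_relationship.bias': [0.0211, -0.0021],
}
missing, unexpected = [], []
load(model, checkpoint, '', missing, unexpected)
print(model.children['cls'].children['predictions'].params['bias'])
print(missing, unexpected)
```

    In this toy run, the zero bias is replaced purely because the names match. The next-sentence key has no matching module under cls, so it lands in unexpected. This mirrors the kind of entry the "Weights from pretrained model not used" log line would report.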

    This load function copies in all the pretrained parameters. So which bias is this one, exactly? It belongs to this class:

    class BertLMPredictionHead(nn.Module):
        """
        Arch:
            - BertPredictionHeadTransform (Input=torch.Size([1, 11, 768]), Output=torch.Size([1, 11, 768]))
                - Dense (768, 768)
                - Activation (gelu)
                - LayerNorm
            - Linear (768, 30522)
    
        y = x * W^T + b
        y = self.decoder(hidden_states) + self.bias, i.e. hidden_states @ self.decoder.weight.T + self.bias
        i.e., torch.Size([1, 11, 768]) @ torch.Size([768, 30522]) + torch.Size([30522]) -> torch.Size([1, 11, 30522])
    
        Input:
            torch.Size([1, 11, 768])
        Output:
            torch.Size([1, 11, 30522])
    
        The purpose is to Decode.
        """
        def __init__(self, config, bert_model_embedding_weights):
            super(BertLMPredictionHead, self).__init__()
            self.transform = BertPredictionHeadTransform(config)
    
            """
            bert_model_embedding_weights.size():
                torch.Size([30522, 768])
            """
            # The output weights are the same as the input embeddings, but there is
            # an output-only bias for each token.
            self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                     bert_model_embedding_weights.size(0),
                                     bias=False)  # maps 768 -> 30522; nn.Linear stores weight as torch.Size([30522, 768])
            self.decoder.weight = bert_model_embedding_weights  # torch.Size([30522, 768])
            self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))  # torch.Size([30522])
    
        def forward(self, hidden_states):
            """
            hidden_states:
                torch.Size([1, 11, 768])
    
            torch.Size([1, 11, 768]) --> torch.Size([1, 11, 768])
            """
            hidden_states = self.transform(hidden_states)
            """
            To predict the corresponding word in the vocab.

            Each of the 11 positions gets a score vector of size [30522], one entry per vocabulary token.
            """
            hidden_states = self.decoder(hidden_states) + self.bias  # torch.Size([1, 11, 30522])
            return hidden_states

    It is this bias:

            self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))  # torch.Size([30522])
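    The decoder has bias=False, and its weight is the word-embedding matrix. So each vocabulary logit is the dot product of the transformed hidden state with that token's embedding row, plus that token's entry in self.bias.

    The toy sketch below is pure Python with made-up sizes and numbers. Hidden is assumed to already be the output of the transform step. With the all-zero bias of a freshly built head, tokens are ranked purely by similarity to their embeddings. A non-zero bias can change which token wins.

```python
def decode(hidden, embedding, bias):
    """logits[v] = dot(hidden, embedding[v]) + bias[v]; embedding is [vocab, hidden]."""
    return [sum(h * w for h, w in zip(hidden, row)) + b
            for row, b in zip(embedding, bias)]


def argmax(values):
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


# Hypothetical toy sizes: vocab=3, hidden=2 (BERT-base uses 30522 and 768).
embedding = [[1.0, 0.0],
             [0.0, 1.0],
             [1.0, 1.0]]
hidden = [0.5, 2.0]

zero_bias = [0.0, 0.0, 0.0]      # like a freshly constructed head
loaded_bias = [0.0, 0.0, -2.5]   # illustrative values, not real pretrained ones

print(decode(hidden, embedding, zero_bias), argmax(decode(hidden, embedding, zero_bias)))
print(decode(hidden, embedding, loaded_bias), argmax(decode(hidden, embedding, loaded_bias)))
```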

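    Why did this strike me as odd? Because this class is not part of the BERT model's own weights. It is an add-on head used to predict the ids of the [MASK] tokens. So I went looking for the big state_dict of pretrained weights, which looks like this: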

    'cls.predictions.bias' = {Tensor: 30522} tensor([-0.4191, -0.4202, -0.4191,  ..., -0.7900, -0.7822, -0.4965])
    'cls.predictions.transform.dense.weight' = {Tensor: 768} tensor([[ 0.3681,  0.0147,  0.0430,  ...,  0.0384, -0.0296,  0.0227],
            [ 0.0034,  0.2647, -0.0618,  ..., -0.0397, -0.0335,  0.0203],
            [ 0.0179, -0.0060,  0.1788,  ...,  0.0267,  0.0555, -0.0432],
            ...,
            [ 0.0784,  0.0172,  0.0583,  ...,  0.3548,  0.0209, -0.0261],
            [ 0.0175, -0.0466,  0.0834,  ...,  0.0069,  0.2132, -0.0503],
            [-0.0832,  0.0461,  0.0490,  ..., -0.0116, -0.0594,  0.3525]])
    'cls.predictions.transform.dense.bias' = {Tensor: 768} tensor([ 5.3890e-02,  1.0068e-01,  4.5532e-02,  2.7030e-02,  3.8845e-02,
             3.3157e-02,  4.1188e-02,  2.8206e-02,  2.4197e-02,  1.3879e-01,
             4.4386e-02,  4.8806e-02,  3.4415e-02,  5.9976e-02,  4.2772e-02,
             2.5261e-02,  1.0533e-01,  4.1858e-02,  4.9016e-02,  9.8930e-02,
             2.4026e-02,  4.1394e-02,  4.2273e-02,  2.9724e-02,  1.0857e-01,
             4.8379e-02,  3.6337e-02,  5.2781e-02,  2.9902e-02,  2.6919e-02,
             2.1127e-02,  4.8463e-02,  5.7389e-02,  4.8581e-02,  9.8151e-02,
             6.3899e-02,  4.4544e-02,  4.9595e-02,  4.5315e-02,  3.5128e-02,
             3.4962e-02,  6.9260e-02,  4.8273e-02,  4.3921e-02,  3.6126e-02,
             3.9017e-02,  4.7681e-02,  4.1840e-02,  4.2173e-02,  5.2243e-02,
             3.3530e-02,  4.3681e-02,  9.2896e-02, -1.3240e-01,  3.5652e-02,
             3.2232e-02,  6.1398e-02,  3.9744e-02,  4.3546e-02,  3.7697e-02,
             3.2834e-02,  2.5923e-02, -7.8080e-02,  2.7405e-02,  7.5468e-02,
             3.8439e-02,  8.4586e-02,  3.0094e-02,  3.6...
    'cls.predictions.decoder.weight' = {Tensor: 30522} tensor([[-0.0102, -0.0615, -0.0265,  ..., -0.0199, -0.0372, -0.0098],
            [-0.0117, -0.0600, -0.0323,  ..., -0.0168, -0.0401, -0.0107],
            [-0.0198, -0.0627, -0.0326,  ..., -0.0165, -0.0420, -0.0032],
            ...,
            [-0.0218, -0.0556, -0.0135,  ..., -0.0043, -0.0151, -0.0249],
            [-0.0462, -0.0565, -0.0019,  ...,  0.0157, -0.0139, -0.0095],
            [ 0.0015, -0.0821, -0.0160,  ..., -0.0081, -0.0475,  0.0753]])
    'cls.seq_relationship.weight' = {Tensor: 2} tensor([[-0.0154, -0.0062, -0.0137,  ..., -0.0128, -0.0099,  0.0006],
            [ 0.0058,  0.0120,  0.0128,  ...,  0.0088,  0.0137, -0.0162]])
    'cls.seq_relationship.bias' = {Tensor: 2} tensor([ 0.0211, -0.0021])

    There are over a hundred differently named weights in total, and a few of them have names starting with cls, as shown above.

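    Reading the code, parameters are matched and loaded by name. So the model's cls.predictions.bias, which was all zeros, gets overwritten with the pretrained values.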
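    When inspecting a checkpoint like this, it helps to group the keys by their first dotted component and list everything under a given prefix. This is a small pure-Python sketch. The helper names are hypothetical. The key list reuses the names from the dump above, plus one bert. key (the word-embedding weight) added for illustration. With a real checkpoint, one would pass torch.load(weights_path).keys() instead.

```python
from collections import Counter


def count_by_top_level(keys):
    """Count checkpoint keys by their first dotted component, e.g. 'bert' or 'cls'."""
    return Counter(key.split('.', 1)[0] for key in keys)


def keys_under(keys, prefix):
    """Sorted list of the keys that start with the given prefix."""
    return sorted(key for key in keys if key.startswith(prefix))


# Names copied from the dump above, plus one illustrative 'bert.' key.
keys = [
    'bert.embeddings.word_embeddings.weight',
    'cls.predictions.bias',
    'cls.predictions.transform.dense.weight',
    'cls.predictions.transform.dense.bias',
    'cls.predictions.decoder.weight',
    'cls.seq_relationship.weight',
    'cls.seq_relationship.bias',
]
print(count_by_top_level(keys))
print(keys_under(keys, 'cls.predictions.'))
```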

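    This surprised me at first, since I didn't expect the dict to contain such an entry. On reflection, pretraining most likely used this [MASK]-prediction head too, so its weights were saved along with everything else. BERT is pretrained with masked-language-model and next-sentence-prediction objectives, which matches the cls.predictions.* and cls.seq_relationship.* entries.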

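    Likewise, cls.predictions.decoder.weight seems to be overwritten as well. That would make it unnecessary for the model to initialize this weight from the Embedding layer's weight up front. In the code, the weight handed over directly from bert looks like this: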

    Parameter containing:
    tensor([[-0.0102, -0.0615, -0.0265,  ..., -0.0199, -0.0372, -0.0098],
            [-0.0117, -0.0600, -0.0323,  ..., -0.0168, -0.0401, -0.0107],
            [-0.0198, -0.0627, -0.0326,  ..., -0.0165, -0.0420, -0.0032],
            ...,
            [-0.0218, -0.0556, -0.0135,  ..., -0.0043, -0.0151, -0.0249],
            [-0.0462, -0.0565, -0.0019,  ...,  0.0157, -0.0139, -0.0095],
            [ 0.0015, -0.0821, -0.0160,  ..., -0.0081, -0.0475,  0.0753]],
           requires_grad=True)
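    The values -0.0102, -0.0615, ... match the start of the fourth entry above (cls.predictions.decoder.weight). So we can reasonably conclude that the two weights are identical, i.e. both are the Embedding layer's weights.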
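    Why does loading both keys do no harm? Parameter loading in PyTorch's default _load_from_state_dict writes into the existing parameter in place (via copy_). The decoder and the embedding layer share one Parameter object. So both checkpoint entries land in the same storage, and the tie survives. The checkpoint simply stores the same matrix under two names.

    The pure-Python analogy below illustrates this. Param is a hypothetical stand-in for nn.Parameter, and the two values are the first numbers from the dump.

```python
class Param:
    """Hypothetical stand-in for nn.Parameter: a mutable list of values."""

    def __init__(self, values):
        self.data = list(values)

    def copy_(self, values):
        # In-place update, analogous to the copy_ used when loading a state_dict.
        self.data[:] = values


embedding_weight = Param([0.0, 0.0])
decoder_weight = embedding_weight  # tied: the very same object, as in BertLMPredictionHead

checkpoint = {
    'bert.embeddings.word_embeddings.weight': [-0.0102, -0.0615],
    'cls.predictions.decoder.weight': [-0.0102, -0.0615],
}
# Both keys are loaded, one after the other, into the shared object.
embedding_weight.copy_(checkpoint['bert.embeddings.word_embeddings.weight'])
decoder_weight.copy_(checkpoint['cls.predictions.decoder.weight'])

print(decoder_weight is embedding_weight, decoder_weight.data)
```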
    As for a conclusion... this pretrained checkpoint could probably be slimmed down a little (said timidly, and only half-seriously).
  • Original post: https://www.cnblogs.com/DDBD/p/14470519.html