  • Converting a TensorFlow .ckpt model into a PyTorch .bin model

    Final exams are finally over~ Time to write up a problem I didn't get around to recording before revision started.

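    Using the official google-research BERT source code (TensorFlow version), I fine-tuned on a new legal corpus for 100,000 iterations, saving a checkpoint every 1,000 iterations. This left a series of checkpoint files in the output directory.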

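    I took the last three files (the .data, .index and .meta files of the final checkpoint) and renamed them to bert_model.ckpt.data-00000-of-00001, bert_model.ckpt.index and bert_model.ckpt.meta.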
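The renaming step can also be scripted. Below is a minimal Python sketch, assuming the training output directory uses Estimator-style names such as model.ckpt-100000.index, model.ckpt-100000.meta and model.ckpt-100000.data-00000-of-00001. The names export_latest_checkpoint, missing_files and REQUIRED are hypothetical helpers, not part of the BERT repo. The sketch picks the newest step numerically (a plain string sort would rank model.ckpt-99000 above model.ckpt-100000), copies instead of renaming so the training output stays intact, and reports which companion files the conversion still needs. If TensorFlow is installed, tf.train.latest_checkpoint(train_dir) is an alternative way to find the newest checkpoint prefix.

```python
import re
import shutil
from pathlib import Path

# Assumed Estimator-style naming: model.ckpt-<step>.<suffix>
CKPT_RE = re.compile(r"^model\.ckpt-(\d+)\.(index|meta|data-\d+-of-\d+)$")

# Files the conversion script expects side by side (single-shard checkpoint assumed).
REQUIRED = [
    "bert_model.ckpt.data-00000-of-00001",
    "bert_model.ckpt.index",
    "bert_model.ckpt.meta",
    "config.json",
    "vocab.txt",
]


def export_latest_checkpoint(train_dir, model_dir):
    """Copy the newest model.ckpt-<step>.* files into model_dir as bert_model.ckpt.*.

    Returns the global step that was exported.
    """
    train_dir, model_dir = Path(train_dir), Path(model_dir)
    # Group checkpoint files by their global step; other files are ignored.
    by_step = {}
    for path in train_dir.iterdir():
        match = CKPT_RE.match(path.name)
        if match:
            by_step.setdefault(int(match.group(1)), []).append((path, match.group(2)))
    if not by_step:
        raise FileNotFoundError(f"no model.ckpt-* files in {train_dir}")
    latest = max(by_step)  # integer comparison, so 100000 > 99000
    model_dir.mkdir(parents=True, exist_ok=True)
    for path, suffix in by_step[latest]:
        # e.g. model.ckpt-100000.index -> bert_model.ckpt.index
        shutil.copy2(path, model_dir / f"bert_model.ckpt.{suffix}")
    return latest


def missing_files(model_dir):
    """Return the REQUIRED file names that are not yet present in model_dir."""
    model_dir = Path(model_dir)
    return [name for name in REQUIRED if not (model_dir / name).is_file()]
```

After exporting, copy config.json and vocab.txt into the same directory; missing_files should then return an empty list.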

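    Together with the config.json and vocab.txt used earlier for fine-tuning (all in the same directory), running the script below produces pytorch_model.bin, which can then be loaded from PyTorch code. The script itself only reads config.json and the checkpoint; vocab.txt is needed later by the tokenizer.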

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function

    import argparse
    import logging

    import torch
    from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert

    logging.basicConfig(level=logging.INFO)


    def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
        # Build an untrained PyTorch model whose architecture matches config.json
        config = BertConfig.from_json_file(bert_config_file)
        print("Building PyTorch model from configuration: {}".format(str(config)))
        model = BertForPreTraining(config)

        # Copy the weights from the TF checkpoint into the PyTorch model.
        # This needs TensorFlow installed; tf_checkpoint_path is the checkpoint
        # prefix, i.e. without the .index/.meta/.data-* suffix.
        load_tf_weights_in_bert(model, config, tf_checkpoint_path)

        # Save only the parameters (state_dict), not the whole pickled model
        print("Save PyTorch model to {}".format(pytorch_dump_path))
        torch.save(model.state_dict(), pytorch_dump_path)


    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        ## Parameters (defaults point at the fine-tuned model directory)
        parser.add_argument("--tf_checkpoint_path",
                            default='./chinese_L-12_H-768_A-12_improve1/bert_model.ckpt',
                            type=str,
                            help="Path prefix of the TensorFlow checkpoint (without file suffix).")
        parser.add_argument("--bert_config_file",
                            default='./chinese_L-12_H-768_A-12_improve1/config.json',
                            type=str,
                            help="The config json file corresponding to the pre-trained BERT model. "
                                 "This specifies the model architecture.")
        parser.add_argument("--pytorch_dump_path",
                            default='./chinese_L-12_H-768_A-12_improve1/pytorch_model.bin',
                            type=str,
                            help="Path to the output PyTorch model.")
        args = parser.parse_args()
        convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                         args.bert_config_file,
                                         args.pytorch_dump_path)
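Conceptually, load_tf_weights_in_bert walks every variable in the checkpoint and maps its TF scope name onto a PyTorch attribute path. The stdlib-only sketch below approximates that name mapping as found in transformers 4.x; tf_name_to_torch is a hypothetical helper, and the real function also handles a few special cases (for example squad mapped to classifier) and copies the arrays themselves, so details may differ between versions.

```python
import re

# Optimizer/bookkeeping variables that the transformers loader skips.
SKIP_PARTS = {"adam_v", "adam_m", "AdamWeightDecayOptimizer",
              "AdamWeightDecayOptimizer_1", "global_step"}


def tf_name_to_torch(tf_name):
    """Map a TF BERT variable name to (pytorch_name, needs_transpose).

    Returns None for variables that hold no model weights.
    Simplified sketch of the renaming rules, not the library code.
    """
    parts = tf_name.split("/")
    if any(part in SKIP_PARTS for part in parts):
        return None
    out = []
    for part in parts:
        match = re.fullmatch(r"([A-Za-z]+)_(\d+)", part)
        if match:
            # layer_3 -> layer.3 (index into an nn.ModuleList)
            out.extend([match.group(1), match.group(2)])
        elif part in ("kernel", "gamma", "output_weights"):
            out.append("weight")
        elif part in ("beta", "output_bias"):
            out.append("bias")
        else:
            out.append(part)
            if part.endswith("_embeddings"):
                # Embedding tables live in nn.Embedding.weight
                out.append("weight")
    # TF dense kernels are [in, out]; nn.Linear.weight is [out, in].
    needs_transpose = parts[-1] == "kernel"
    return ".".join(out), needs_transpose


print(tf_name_to_torch("bert/encoder/layer_11/attention/self/query/kernel"))
```

Skipping the Adam slots (adam_m, adam_v) is also why the resulting .bin is typically much smaller than the fine-tuned .data-00000-of-00001 file: the TF checkpoint stores two optimizer moment estimates alongside every weight.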

  • Original article: https://www.cnblogs.com/cxq1126/p/14277134.html