  • Converting a TensorFlow .ckpt model to a PyTorch .bin model

    Finals are finally over~ Time to catch up on the problems I didn't get to write down before I started reviewing.

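    I fine-tuned BERT on a new legal corpus using the official google-research BERT source code (TensorFlow version). Training ran for 100,000 iterations and saved a checkpoint every 1,000 steps, which leaves a series of checkpoint files in the output directory.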

    Take the last three files (the final checkpoint) and rename them to bert_model.ckpt.data-00000-of-00001, bert_model.ckpt.index, and bert_model.ckpt.meta.
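The renaming step can also be scripted. Below is a minimal sketch, not part of the original workflow: the helper name `export_checkpoint` is hypothetical, and the source prefix (for example `model.ckpt-100000`) is an assumption about how the training run names its checkpoints, so check the actual file names in your output directory first. It copies rather than renames, so the original checkpoint stays in place.

```python
import shutil
from pathlib import Path


def export_checkpoint(src_dir, src_prefix, dst_dir, dst_prefix="bert_model.ckpt"):
    """Copy every file named `<src_prefix>.<suffix>` to `<dst_prefix>.<suffix>`.

    `src_prefix` is an assumed checkpoint name such as "model.ckpt-100000".
    Returns the names of the files written to `dst_dir`.
    """
    src_dir, dst_dir = Path(src_dir), Path(dst_dir)
    dst_dir.mkdir(parents=True, exist_ok=True)
    copied = []
    for src in sorted(src_dir.iterdir()):
        # Require the prefix to be followed by a dot, so that
        # "model.ckpt-1000" does not also match "model.ckpt-100000".
        if src.is_file() and src.name.startswith(src_prefix + "."):
            suffix = src.name[len(src_prefix):]  # e.g. ".index"
            dst = dst_dir / (dst_prefix + suffix)
            shutil.copy2(src, dst)
            copied.append(dst.name)
    return copied
```

Calling `export_checkpoint("./output", "model.ckpt-100000", "./chinese_L-12_H-768_A-12_improve1")` would then be expected to produce the three bert_model.ckpt.* files in the target directory; both directory names here are placeholders.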

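    Put them together with the config.json and vocab.txt files used during fine-tuning, then run the script below to generate pytorch_model.bin, which can then be loaded by PyTorch code.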

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function

    import argparse
    import torch

    # load_tf_weights_in_bert reads the checkpoint through TensorFlow,
    # so TensorFlow must be installed alongside transformers.
    from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert

    import logging
    logging.basicConfig(level=logging.INFO)


    def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
        # Initialise PyTorch model
        config = BertConfig.from_json_file(bert_config_file)
        print("Building PyTorch model from configuration: {}".format(str(config)))
        model = BertForPreTraining(config)

        # Load weights from tf checkpoint
        load_tf_weights_in_bert(model, config, tf_checkpoint_path)

        # Save pytorch-model
        print("Save PyTorch model to {}".format(pytorch_dump_path))
        torch.save(model.state_dict(), pytorch_dump_path)


    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        ## Required parameters
        parser.add_argument("--tf_checkpoint_path",
                            default = './chinese_L-12_H-768_A-12_improve1/bert_model.ckpt',
                            type = str,
                            help = "Path to the TensorFlow checkpoint path.")
        parser.add_argument("--bert_config_file",
                            default = './chinese_L-12_H-768_A-12_improve1/config.json',
                            type = str,
                            help = "The config json file corresponding to the pre-trained BERT model. \n"
                                "This specifies the model architecture.")
        parser.add_argument("--pytorch_dump_path",
                            default = './chinese_L-12_H-768_A-12_improve1/pytorch_model.bin',
                            type = str,
                            help = "Path to the output PyTorch model.")
        args = parser.parse_args()
        convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                         args.bert_config_file,
                                         args.pytorch_dump_path)
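One detail matters when the generated file is loaded by hand. The script saves the state dict of `BertForPreTraining`, so the encoder weights are stored under names starting with `bert.`, next to the pre-training heads under `cls.`. Pointing `BertModel.from_pretrained` at the directory containing config.json and pytorch_model.bin normally handles this mapping on its own. If you call `load_state_dict` on a bare `BertModel` instead, the keys need to be remapped first. A minimal sketch of that remapping follows; the helper name `encoder_only_state_dict` is hypothetical.

```python
def encoder_only_state_dict(state_dict, prefix="bert."):
    """Keep only the encoder weights and strip `prefix` from their names.

    Entries that do not start with `prefix` (such as the `cls.*`
    pre-training heads) are dropped.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Hypothetical usage (needs torch and transformers; paths are the
# defaults from the conversion script above):
#   state = torch.load("./chinese_L-12_H-768_A-12_improve1/pytorch_model.bin",
#                      map_location="cpu")
#   config = BertConfig.from_json_file("./chinese_L-12_H-768_A-12_improve1/config.json")
#   model = BertModel(config)
#   model.load_state_dict(encoder_only_state_dict(state), strict=False)
```

`strict=False` is used in the usage comment because the exact set of buffers in the state dict can differ between transformers versions; inspect the returned missing and unexpected keys to confirm that only non-essential entries were skipped.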

  • Original article: https://www.cnblogs.com/cxq1126/p/14277134.html