期末考终于结束了~补下复习之前没来得及记录的问题。
用google-research官方的bert源码(tensorflow版本)对新的法律语料进行微调,迭代次数为100000次,每隔1000次保存一下模型,得到的结果如下:
将最后三个文件取出,改名为bert_model.ckpt.data-00000-of-00001、bert_model.ckpt.index、bert_model.ckpt.meta
加上之前微调使用过的config.json以及vocab.txt文件,运行如下文件后生成pytorch_model.bin,之后就可以被pytorch的代码调用了。
1 from __future__ import absolute_import 2 from __future__ import division 3 from __future__ import print_function 4 5 import argparse 6 import torch 7 8 from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert 9 10 import logging 11 logging.basicConfig(level=logging.INFO) 12 13 def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 14 # Initialise PyTorch model 15 config = BertConfig.from_json_file(bert_config_file) 16 print("Building PyTorch model from configuration: {}".format(str(config))) 17 model = BertForPreTraining(config) 18 19 # Load weights from tf checkpoint 20 load_tf_weights_in_bert(model, config, tf_checkpoint_path) 21 22 # Save pytorch-model 23 print("Save PyTorch model to {}".format(pytorch_dump_path)) 24 torch.save(model.state_dict(), pytorch_dump_path) 25 26 # 27 if __name__ == "__main__": 28 parser = argparse.ArgumentParser() 29 ## Required parameters 30 parser.add_argument("--tf_checkpoint_path", 31 default = './chinese_L-12_H-768_A-12_improve1/bert_model.ckpt', 32 type = str, 33 help = "Path to the TensorFlow checkpoint path.") 34 parser.add_argument("--bert_config_file", 35 default = './chinese_L-12_H-768_A-12_improve1/config.json', 36 type = str, 37 help = "The config json file corresponding to the pre-trained BERT model. " 38 "This specifies the model architecture.") 39 parser.add_argument("--pytorch_dump_path", 40 default = './chinese_L-12_H-768_A-12_improve1/pytorch_model.bin', 41 type = str, 42 help = "Path to the output PyTorch model.") 43 args = parser.parse_args() 44 convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 45 args.bert_config_file, 46 args.pytorch_dump_path)