  • BERT fine-tuning (1)

    Steps for fine-tuning BERT:

    Let's start with the main function.

    Copy run_classifier.py and give it any name you like, e.g. my_classifier.py.

    Here is the main function:

    if __name__ == "__main__":
      flags.mark_flag_as_required("data_dir")
      flags.mark_flag_as_required("task_name")
      flags.mark_flag_as_required("vocab_file")
      flags.mark_flag_as_required("bert_config_file")
      flags.mark_flag_as_required("output_dir")
      tf.app.run()
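    The five mark_flag_as_required calls make main() refuse to start unless all of those flags are supplied. As a rough illustration only, here is a stdlib argparse analogue of that behaviour; BERT itself uses TensorFlow/absl flags, and build_parser and REQUIRED_FLAGS are hypothetical names.

```python
import argparse

# Hypothetical stdlib analogue of the five flags that main() marks as required
# via flags.mark_flag_as_required(); BERT itself uses TensorFlow/absl flags.
REQUIRED_FLAGS = ["data_dir", "task_name", "vocab_file", "bert_config_file", "output_dir"]


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Sketch of my_classifier.py's required flags")
    for name in REQUIRED_FLAGS:
        # required=True makes argparse print an error and exit if the flag is missing,
        # similar in spirit to mark_flag_as_required for absl flags.
        parser.add_argument(f"--{name}", required=True)
    return parser
```

    Calling build_parser().parse_args() with any of the five flags missing makes argparse exit with a usage error before any work is done, the same fail-fast idea as in the original script.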

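    1. data_dir

    In flags.mark_flag_as_required("data_dir"), data_dir is the path to the folder that holds the data. The format of each example is already defined by the InputExample class: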

    class InputExample(object):
      """A single training/test example for simple sequence classification."""
    
      def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs a InputExample.
    
        Args:
          guid: Unique id for the example.
          text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
          text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
          label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label

    The required fields are guid and text_a; text_b and label are optional.

    Single-sentence classification tasks do not need text_b, and test examples do not need a label.
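    To make the optional fields concrete, the sketch below uses a dataclass stand-in with the same four fields as BERT's InputExample (it is not BERT's class, only a mirror of its constructor arguments) and builds one example of each kind. The sentences and guids are made up for illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class InputExample:
    """Stand-in with the same fields as BERT's InputExample (not the real class)."""
    guid: str
    text_a: str
    text_b: Optional[str] = None  # only needed for sentence-pair tasks
    label: Optional[str] = None   # omitted for test examples


# Single-sentence training example: text_b stays None.
train_ex = InputExample(guid="train-1", text_a="The movie was great.", label="1")

# Sentence-pair example (MRPC-style): both texts are given.
pair_ex = InputExample(guid="dev-1", text_a="He said hi.", text_b="He greeted us.", label="1")

# Test example: no label.
test_ex = InputExample(guid="test-1", text_a="The plot was boring.")
```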

    2. task_name

      processors = {
          "cola": ColaProcessor,
          "mnli": MnliProcessor,
          "mrpc": MrpcProcessor,
          "xnli": XnliProcessor,
      }

    task_name is a key in the processors dictionary. BERT ships four of them: "cola", "mnli", "mrpc" and "xnli"; for any other task you add your own entry.

    Note the following:

      task_name = FLAGS.task_name.lower()
    
      if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))
    
      processor = processors[task_name]()
    
      label_list = processor.get_labels()
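    The lookup above can be reproduced in isolation. This is a minimal sketch with placeholder DataProcessor and MrpcProcessor classes (not BERT's full implementations) and a hypothetical helper get_processor; it shows that the flag value is lower-cased before the dictionary lookup and that an unknown key, such as a misspelled task name, raises ValueError.

```python
class DataProcessor:
    """Placeholder base class; BERT's real DataProcessor also defines get_*_examples."""

    def get_labels(self):
        raise NotImplementedError()


class MrpcProcessor(DataProcessor):
    def get_labels(self):
        return ["0", "1"]


processors = {
    "mrpc": MrpcProcessor,
}


def get_processor(task_name: str) -> DataProcessor:
    # Same normalisation as main(): the flag is lower-cased before the lookup,
    # so --task_name=MRPC and --task_name=mrpc select the same processor.
    task_name = task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))
    return processors[task_name]()  # the dict stores classes, so instantiate one


label_list = get_processor("MRPC").get_labels()
```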

    task_name selects the processor. BERT's source provides four processors; to fine-tune on our own data we define our own processor, here based on MrpcProcessor:

    class MrpcProcessor(DataProcessor):
      """Processor for the MRPC data set (GLUE version)."""
    
      def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
    
      def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
    
      def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
    
      def get_labels(self):
        """See base class."""
        return ["0", "1"]  # TODO: replace with your task's label set
    
      def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
          if i == 0:
            continue
          guid = "%s-%s" % (set_type, i)
          text_a = tokenization.convert_to_unicode(line[3])
          text_b = tokenization.convert_to_unicode(line[4])
          if set_type == "test":
            label = "0"
          else:
            label = tokenization.convert_to_unicode(line[0])
          examples.append(
              InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

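    The processor is the class that prepares the data: it inherits from DataProcessor and turns the raw input files into examples. The files in data_dir are .tsv files, and because this is a binary classification task the labels are set to "0" and "1".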
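    get_labels() returns label strings, not numbers. Later in run_classifier.py, convert_single_example maps each label string to its index in that list, so every label that appears in train.tsv and dev.tsv must be listed in get_labels(). The sketch below reproduces that mapping; label_ids and its up-front check are hypothetical additions for illustration.

```python
def build_label_map(label_list):
    # Each label string is mapped to its position in get_labels(),
    # the same mapping convert_single_example builds.
    return {label: i for i, label in enumerate(label_list)}


def label_ids(labels, label_list):
    label_map = build_label_map(label_list)
    unknown = sorted(set(labels) - set(label_map))
    if unknown:
        # Without this check, an unlisted label would fail later with a KeyError.
        raise ValueError("labels missing from get_labels(): %s" % unknown)
    return [label_map[label] for label in labels]


print(label_ids(["1", "0", "1"], ["0", "1"]))  # [1, 0, 1]
```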

    _create_examples() also specifies how each guid is built and which columns of each row are assigned to text_a, text_b and label.
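    The reading and column mapping can be tried without TensorFlow. This sketch assumes MRPC's layout (a header row, the label in column 0, the two sentences in columns 3 and 4), reads tab-separated rows with quote handling disabled, roughly what DataProcessor._read_tsv does, and returns plain tuples instead of InputExample objects. tokenization.convert_to_unicode is omitted because Python 3 strings are already Unicode; read_tsv and create_mrpc_examples are hypothetical names.

```python
import csv
import io


def read_tsv(f):
    # Tab-separated rows, no quote processing.
    return list(csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE))


def create_mrpc_examples(lines, set_type):
    """Returns (guid, text_a, text_b, label) tuples using MrpcProcessor's column layout."""
    examples = []
    for i, line in enumerate(lines):
        if i == 0:  # row 0 is the header and is skipped
            continue
        guid = "%s-%s" % (set_type, i)
        text_a, text_b = line[3], line[4]
        # test labels are not used, so a dummy "0" is filled in
        label = "0" if set_type == "test" else line[0]
        examples.append((guid, text_a, text_b, label))
    return examples


sample = io.StringIO(
    "Quality\t#1 ID\t#2 ID\t#1 String\t#2 String\n"
    "1\t101\t102\tThe cat sat.\tA cat was sitting.\n"
)
print(create_mrpc_examples(read_tsv(sample), "train"))
# [('train-1', 'The cat sat.', 'A cat was sitting.', '1')]
```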

    That covers the first two required flags in the main function; next come the remaining three.


    3. vocab_file, bert_config_file, output_dir

    vocab_file and bert_config_file are files that come with the downloaded pre-trained model, and output_dir is the directory where the fine-tuned model is written.
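    bert_config_file is also used for a sanity check: run_classifier.py refuses a max_seq_length larger than the max_position_embeddings value in bert_config.json. A minimal sketch of that check, with hypothetical function names:

```python
import json


def load_bert_config(bert_config_file):
    # bert_config.json is plain JSON shipped with the pre-trained checkpoint.
    with open(bert_config_file, "r", encoding="utf-8") as f:
        return json.load(f)


def check_max_seq_length(config, max_seq_length):
    # Position embeddings only exist up to max_position_embeddings,
    # so longer sequences cannot be fed to the pre-trained model.
    limit = config["max_position_embeddings"]
    if max_seq_length > limit:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" % (max_seq_length, limit))
    return limit
```

    In practice you would call check_max_seq_length(load_bert_config(path), 128) with the path to your bert_config.json before starting a run.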

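    The .tsv files mentioned above are similar to .csv files, except that fields are separated by tabs.

    Format of train.tsv and dev.tsv:

    label + "\t" (tab) + sentence

    Format of test.tsv:

    sentence

    This single-sentence layout is not the one MrpcProcessor reads (it skips a header row, takes the label from column 0 and the two sentences from columns 3 and 4), so _create_examples has to be adapted to these columns.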
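    For the layout just described, _create_examples would read the label and sentence from columns 0 and 1. A minimal sketch, assuming no header row and one sentence per row; create_examples_from_rows is a hypothetical helper returning (guid, text_a, label) tuples, and in a real processor each tuple would become an InputExample with text_b=None.

```python
import csv
import io


def create_examples_from_rows(rows, set_type):
    """Builds (guid, text_a, label) tuples for a single-sentence .tsv layout.

    Assumptions (adjust to your data): no header row; train/dev rows are
    [label, sentence]; test rows are [sentence].
    """
    examples = []
    for i, row in enumerate(rows):
        guid = "%s-%s" % (set_type, i)
        if set_type == "test":
            text_a, label = row[0], "0"  # dummy label, as MrpcProcessor does for test
        else:
            label, text_a = row[0], row[1]
        examples.append((guid, text_a, label))
    return examples


rows = list(csv.reader(io.StringIO("1\tGreat film\n0\tBoring plot\n"),
                       delimiter="\t", quoting=csv.QUOTE_NONE))
print(create_examples_from_rows(rows, "train"))
# [('train-0', 'Great film', '1'), ('train-1', 'Boring plot', '0')]
```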

    4. Edit the processors dictionary to register your own task

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "xnli": XnliProcessor,
        # "mytask": MyTaskProcessor,  # hypothetical: register your own processor under a new key
    }

    5. Set the parameters and run fine-tuning

    python my_classifier.py \
      --task_name=mrpc \
      --do_train=true \
      --do_eval=true \
      --data_dir=$GLUE_DIR/MRPC \
      --vocab_file=$BERT_BASE_DIR/vocab.txt \
      --bert_config_file=$BERT_BASE_DIR/bert_config.json \
      --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
      --max_seq_length=128 \
      --train_batch_size=32 \
      --learning_rate=2e-5 \
      --num_train_epochs=3.0 \
      --output_dir=/tmp/mrpc_output/
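    With these flags, run_classifier.py derives the number of optimisation steps from the training-set size, num_train_steps = int(len(train_examples) / train_batch_size * num_train_epochs), and treats the first warmup_proportion (default 0.1) of those steps as learning-rate warm-up. The sketch below reproduces only that arithmetic; train_schedule is a hypothetical name and 1000 is an arbitrary example size.

```python
def train_schedule(num_train_examples, train_batch_size=32,
                   num_train_epochs=3.0, warmup_proportion=0.1):
    # int() truncates, so a partial final batch does not add a step.
    num_train_steps = int(num_train_examples / train_batch_size * num_train_epochs)
    num_warmup_steps = int(num_train_steps * warmup_proportion)
    return num_train_steps, num_warmup_steps


print(train_schedule(1000))  # (93, 9)
```

    This makes it easy to see how changing train_batch_size or num_train_epochs on the command line changes the total training length.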
  • Original post: https://www.cnblogs.com/laowangxieboke/p/12836970.html