zoukankan      html  css  js  c++  java
  • 使用bert进行情感分类

    2018年google推出了bert模型,这个模型的性能要远超于以前所使用的模型,总的来说就是很牛。但是训练bert模型是异常昂贵的,对于一般人来说并不需要自己单独训练bert,只需要加载预训练模型,就可以完成相应的任务。下面我将以情感分类为例,介绍使用bert的方法。这里与我们之前调用API写代码有所区别,已经有大神将bert封装成.py文件,我们只需要简单修改一下,就可以直接调用这些.py文件了。

    官方文档

    1. tensorflow版:点击传送门
    2. pytorch版(注意这是一个第三方团队实现的):点击传送门
    3. 论文:点击传送门
      一切以官方论文为准,如果有什么疑问,请仔细阅读官方文档

    具体实现

    我这里使用的是pytorch版本。

    前置需要

    1. 安装pytorch和tensorflow。
    2. 安装PyTorch pretrained bert。(pip install pytorch-pretrained-bert)
    3. 将pytorch-pretrained-BERT提供的文件,整个下载。
    4. 选择并且下载预训练模型。地址:请点击
      注意这里的model是tensorflow版本的,需要进行相应的转换才能在pytorch中使用

    无论是tf版还是pytorch版本,预训练模型都需要三个文件(或者功能类似的)

    1. 预训练模型文件,里面保存的是模型参数。
    2. config文件,用来加载预训练模型。
    3. vocabulary文件,用于后续分词。

    模型转换

    文档里提供了convert_tf_checkpoint_to_pytorch.py 这个脚本来进行模型转换。使用方法如下:

    export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
    
    pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch 
      $BERT_BASE_DIR/bert_model.ckpt 
      $BERT_BASE_DIR/bert_config.json 
      $BERT_BASE_DIR/pytorch_model.bin
    

    修改源码

    这里是需要实现情感分类。只需要用到run_classifier_dataset_utils.py和run_classifier.py这两个文件。run_classifier_dataset_utils.py是用来处理文本的输入,我们只需要添加一个类用来处理输入即可。

    class MyProcessor(DataProcessor):
        '''Processor for the sentiment classification data set'''
    
        def get_train_examples(self, data_dir):
            """See base class."""
            logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
            return self._create_examples(
                self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
    
        def get_dev_examples(self, data_dir):
            """See base class."""
            return self._create_examples(
                self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
    
        def get_labels(self):
            """See base class."""
            return ["-1", "1"]
    
        def _create_examples(self, lines, set_type):
            """Creates examples for the training and dev sets."""
            examples = []
            for (i, line) in enumerate(lines):
                if i == 0:
                    continue
                guid = "%s-%s" % (set_type, i)
                text_a = line[0]
                label = line[1]
                examples.append(
                    InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
            return examples
    

    train.tsv和dev.tsv分别表示训练集和测试集。记得要在下面的代码加上之前定义的类。

    def compute_metrics(task_name, preds, labels):
        assert len(preds) == len(labels)
        if task_name == "cola":
            return {"mcc": matthews_corrcoef(labels, preds)}
        elif task_name == "sst-2":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "mrpc":
            return acc_and_f1(preds, labels)
        elif task_name == "sts-b":
            return pearson_and_spearman(preds, labels)
        elif task_name == "qqp":
            return acc_and_f1(preds, labels)
        elif task_name == "mnli":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "mnli-mm":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "qnli":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "rte":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "wnli":
            return {"acc": simple_accuracy(preds, labels)}
        elif task_name == "my":
            return acc_and_f1(preds, labels)
        else:
            raise KeyError(task_name)
    
    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mnli-mm": MnliMismatchedProcessor,
        "mrpc": MrpcProcessor,
        "sst-2": Sst2Processor,
        "sts-b": StsbProcessor,
        "qqp": QqpProcessor,
        "qnli": QnliProcessor,
        "rte": RteProcessor,
        "wnli": WnliProcessor,
        "my": MyProcessor
    }
    
    output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification",
        "wnli": "classification",
        "my": "classification"
    }
    

    运行bert

    编辑shell脚本:

    #!/bin/bash 
    export TASK_NAME=my
    
    python run_classifier.py 
      --task_name $TASK_NAME 
      --do_train 
      --do_eval 
      --do_lower_case 
      --data_dir /home/garvey/Yuqinfenxi/ 
      --bert_model /home/garvey/uncased_L-12_H-768_A-12 
      --max_seq_length 410 
      --train_batch_size 8 
      --learning_rate 2e-5 
      --num_train_epochs 3.0 
      --output_dir /home/garvey/bertmodel
    

    运行即可。这里要注意max_seq_length和train_batch_size这两个参数,设置过大是很容易爆掉显存的,一般来说运行bert需要11G左右的显存。

    备注

    max_seq_length是指词的数量而不是指字符的数量。参考代码中的注释:

    The maximum total input sequence length after WordPiece tokenization. Sequences longer than this will be truncated, and sequences shorter than this will be padded.

    对于sequence的理解,网上很多博客都把这个翻译为句子,我个人认为是不准确的,序列是可以包含多个句子的,而不只是单独一个句子。

    注意

    Bert开源的代码中,只提供了train和dev数据,也就是训练集和验证集。对于评测论文标准数据集的时候,只需要把训练集和测试集送进去就可以得到结果,这一过程是没有调参的(没有验证集),都是使用默认参数。但是如果用Bert来打比赛,注意这个时候的测试集是没有标签的,这就需要在源码中加上一个处理test数据集的部分,并且通过验证集来选择参数。

    补充

    在大的预训练模型例如像bert-large在对小的训练集进行精细调整的时候,往往会导致性能退化:模型要么运行良好,要么根本不起作用,在我们用bert-large对一些小数据集进行微调,直接使用默认参数的话二分类的准确率只有0.5,也就是一点作用也没有,这个时候需要对学习率和迭代次数进行一个调整才会有一个正常的结果,这个问题暂时还没有得到解决。

  • 相关阅读:
    佛學概要十四講表
    冰川时代4中英台词全集
    Linux Mysql 每天定时备份
    zabbix拓扑图
    搭建zabbix 3.4
    ★日常工作保养电脑及设备★
    宽带突然断网了,需要做如下应急措施
    预防这几点,可以让你的电脑长久耐用!!!!
    搭建简易的 DISCUZ论坛
    format 的常见用法
  • 原文地址:https://www.cnblogs.com/mlgjb/p/11158009.html
Copyright © 2011-2022 走看看