zoukankan      html  css  js  c++  java
  • 【文本分类-08】BERT

    目录

    1. 大纲概述
    2. 数据集合
    3. 数据处理
    4. 预训练word2vec模型

    一、大纲概述

    文本分类这个系列将会有8篇左右文章,从github直接下载代码,从百度云下载训练数据,在pycharm上导入即可使用,包括基于word2vec预训练的文本分类,与及基于近几年的预训练模型(ELMo,BERT等)的文本分类。总共有以下系列:

    word2vec预训练词向量

    textCNN 模型

    charCNN 模型

    Bi-LSTM 模型

    Bi-LSTM + Attention 模型

    Transformer 模型

    ELMo 预训练模型

    BERT 预训练模型

    二、数据集合

    数据集为IMDB 电影影评,总共有三个数据文件,在/data/rawData目录下,包括unlabeledTrainData.tsv,labeledTrainData.tsv,testData.tsv。在进行文本分类时需要有标签的数据(labeledTrainData),但是在训练word2vec词向量模型(无监督学习)时可以将无标签的数据一起用上。

    训练数据地址:链接:https://pan.baidu.com/s/1-XEwx1ai8kkGsMagIFKX_g 提取码:rtz8

    三、bert模型

      BERT 模型来源于论文BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding。BERT模型是Google提出的基于双向Transformer构建的语言模型。BERT模型和ELMo有大不同,在之前的预训练模型(包括word2vec,ELMo等)都会生成词向量,这种类别的预训练模型属于domain transfer。BERT 模型是将预训练模型和下游任务模型结合在一起的,也就是说在做下游任务时仍然是用BERT模型,详细介绍请看这篇博文。Google提供了下面七种预训练好的模型文件。

    BERT模型在英文数据集上提供了两种大小的模型,Base和Large。Uncased是意味着输入的词都会转变成小写,cased是意味着输入的词会保存其大写(在命名实体识别等项目上需要)。Multilingual是支持多语言的,最后一个是中文预训练模型。

    在这里我们选择BERT-Base,Uncased。下载下来之后是一个zip文件,解压后有ckpt文件,一个模型参数的json文件,一个词汇表txt文件。

      在应用BERT模型之前,我们需要去github上下载开源代码,我们可以直接clone下来,在这里有一个run_classifier.py文件,在做文本分类项目时,我们需要修改这个文件,主要是添加我们的数据预处理类。clone下来的项目结构如下:

    修改部分如下:

    3.1 新增IMDBProcessor

    在run_classifier.py文件中有一个基类DataProcessor类,在这个基类中定义了一个读取文件的静态方法_read_tsv,四个分别获取训练集,验证集,测试集和标签的方法。接下来我们要定义自己的数据处理的类,我们将我们的类命名为

    class IMDBProcessor(DataProcessor):
        """
        IMDB data processor
        """
        def _read_csv(self, data_dir, file_name):
            with tf.gfile.Open(data_dir + file_name, "r") as f:
                reader = csv.reader(f, delimiter=",", quotechar=None)
                lines = []
                for line in reader:
                    lines.append(line)
            return lines
    
        def get_train_examples(self, data_dir):
            lines = self._read_csv(data_dir, "trainData.csv")
            examples = []
            for (i, line) in enumerate(lines):
                if i == 0:
                    continue
                guid = "train-%d" % (i)
                text_a = tokenization.convert_to_unicode(line[0])
                label = tokenization.convert_to_unicode(line[1])
                examples.append(
                    InputExample(guid=guid, text_a=text_a, label=label))
            return examples
        def get_dev_examples(self, data_dir):
            lines = self._read_csv(data_dir, "devData.csv")
    
            examples = []
            for (i, line) in enumerate(lines):
                if i == 0:
                    continue
                guid = "dev-%d" % (i)
                text_a = tokenization.convert_to_unicode(line[0])
                label = tokenization.convert_to_unicode(line[1])
                examples.append(
                    InputExample(guid=guid, text_a=text_a, label=label))
            return examples
    
        def get_test_examples(self, data_dir):
            lines = self._read_csv(data_dir, "testData.csv")
    
            examples = []
            for (i, line) in enumerate(lines):
                if i == 0:
                    continue
                guid = "test-%d" % (i)
                text_a = tokenization.convert_to_unicode(line[0])
                label = tokenization.convert_to_unicode(line[1])
                examples.append(
                    InputExample(guid=guid, text_a=text_a, label=label))
            return examples
    
        def get_labels(self):
            return ["0", "1"]

    3.2 main函数下面的processors

     在这里我们没有直接用基类中的静态方法_read_tsv,因为我们的csv文件是用逗号分隔的,因此就自己定义了一个_read_csv的方法,其余的方法就是读取训练集,验证集,测试集和标签。在这里标签就是一个列表,将我们的类别标签放入就行。训练集,验证集和测试集都是返回一个InputExample对象的列表。InputExample是run_classifier.py中定义的一个类。在这个类中定义了text_a和text_b,说明是支持句子对的输入的,不过我们这里做文本分类只有一个句子的输入,因此text_b可以不传参。

      另外从上面我们自定义的数据处理类中可以看出,训练集和验证集是保存在不同文件中的,因此我们需要将我们之前预处理好的数据提前分割成训练集和验证集,并存放在同一个文件夹下面,文件的名称要和类中方法里的名称相同。

      到这里之后我们已经准备好了我们的数据集,并定义好了数据处理类,此时我们需要将我们的数据处理类加入到run_classifier.py文件中的main函数下面的processors字典中,结果如下:

        1 	def main(_):
        2 	  tf.logging.set_verbosity(tf.logging.INFO)
        3 	
        4 	  processors = {
        5 	      "cola": ColaProcessor,
        6 	      "mnli": MnliProcessor,
        7 	      "mrpc": MrpcProcessor,
        8 	      "xnli": XnliProcessor,
        9 	      "imdb": IMDBProcessor     #这一句是新增的,用于调用完成本次任务
       10 	  }

    3.3 配置脚本

    --data_dir=../MY_DATASET/

    --task_name=imdb

    --vocab_file=../BERT_BASE_DIR/uncased_L-12_H-768_A-12/vocab.txt

    --bert_config_file=../BERT_BASE_DIR/uncased_L-12_H-768_A-12/bert_config.json

    --do_train=true

    --do_eval=true

    --init_checkpoint=../BERT_BASE_DIR/uncased_L-12_H-768_A-12/bert_model.ckpt

    --max_seq_length=128

    --train_batch_size=16

    --learning_rate=5e-5

    --num_train_epochs=2.0

    --output_dir=../output/

    具体在pycharm这种IDE运行的办法可见这篇博文

       

    相关代码可见:https://github.com/yifanhunter/NLP_textClassifier

    主要参考

    【1】 https://home.cnblogs.com/u/jiangxinyang/

  • 相关阅读:
    【安装软件的点点滴滴】
    【自然语言处理】LDA
    【sklearn】数据预处理 sklearn.preprocessing
    【sklearn】中文文档
    【MySql】update用法
    DotNet Core
    ASP.NET MVC
    ADO.NET
    RESTful API
    C#
  • 原文地址:https://www.cnblogs.com/yifanrensheng/p/13369314.html
Copyright © 2011-2022 走看看