zoukankan      html  css  js  c++  java
  • NLP(三十):BertForSequenceClassification:Kaggle的bert文本分类,基于transformers的BERT分类

    Bert是非常强化的NLP模型,在文本分类的精度非常高。本文将介绍Bert中文文本分类的基础步骤,文末有代码获取方法。

    步骤1:读取数据

    本文选取了头条新闻分类数据集来完成分类任务,此数据集是根据头条新闻的标题来完成分类。

    101 京城最值得你来场文化之旅的博物馆_!_保利集团,马未都,中国科学技术馆,博物馆,新中国
    101 发酵床的垫料种类有哪些?哪种更好?
    101 上联:黄山黄河黄皮肤黄土高原。怎么对下联?
    101 林徽因什么理由拒绝了徐志摩而选择梁思成为终身伴侣?
    101 黄杨木是什么树?
    

    首先需要下载数据,并解压数据:

    wget http://github.com/skdjfla/toutiao-text-classfication-dataset/raw/master/toutiao_cat_data.txt.zip
    !unzip toutiao_cat_data.txt.zip
    

    按照数据集格式读取新闻标题和新闻标签:

    import pandas as pd
    import codecs
    
    # 标签
    news_label = [int(x.split('_!_')[1])-100 
                      for x in codecs.open('toutiao_cat_data.txt')]
    # 文本
    news_text = [x.strip().split('_!_')[-1] if x.strip()[-3:] != '_!_' else x.strip().split('_!_')[-2]
                     for x in codecs.open('toutiao_cat_data.txt')]
    

    步骤2:划分数据集

    借助train_test_split划分20%的数据为验证集,并保证训练集和验证部分类别同分布。

    import torch
    from sklearn.model_selection import train_test_split
    from torch.utils.data import Dataset, DataLoader, TensorDataset
    import numpy as np
    import pandas as pd
    import random
    import re
    
    # 划分为训练集和验证集
    # stratify 按照标签进行采样,训练集和验证部分同分布
    x_train, x_test, train_label, test_label =  train_test_split(news_text[:], 
                          news_label[:], test_size=0.2, stratify=news_label[:])
    

    步骤3:对文本进行编码

    使用transformers对文本进行转换,这里使用的是bert-base-chinese模型,所以加载的Tokenizer也要对应。

    # transformers bert相关的模型使用和加载
    from transformers import BertTokenizer
    # 分词器,词典
    
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    train_encoding = tokenizer(x_train, truncation=True, padding=True, max_length=64)
    test_encoding = tokenizer(x_test, truncation=True, padding=True, max_length=64)
    

    使用编码后的数据构建Dataset:

    # 数据集读取
    class NewsDataset(Dataset):
        def __init__(self, encodings, labels):
            self.encodings = encodings
            self.labels = labels
        
        # 读取单个样本
        def __getitem__(self, idx):
            item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
            item['labels'] = torch.tensor(int(self.labels[idx]))
            return item
        
        def __len__(self):
            return len(self.labels)
    
    train_dataset = NewsDataset(train_encoding, train_label)
    test_dataset = NewsDataset(test_encoding, test_label)
    

    这里dataset是直接读取文本在经过所以加载的Tokenizer处理后的数据,主要的含义如下:

    • input_ids:字的编码
    • token_type_ids:标识是第一个句子还是第二个句子
    • attention_mask:标识是不是填充

    步骤4:定义Bert模型

    由于这里是文本分类任务,所以直接使用BertForSequenceClassification完成加载即可,这里需要制定对应的类别数量。

    from transformers import BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup
    model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=17)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    # 单个读取到批量读取
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True)
    
    # 优化方法
    optim = AdamW(model.parameters(), lr=2e-5)
    total_steps = len(train_loader) * 1
    scheduler = get_linear_schedule_with_warmup(optim, 
                                                num_warmup_steps = 0, # Default value in run_glue.py
                                                num_training_steps = total_steps)
    

    步骤5:模型训练与验证

    使用常规的正向传播和反向传播即可,在训练过程中计算类别准确率。

    # 训练函数
    def train():
        model.train()
        total_train_loss = 0
        iter_num = 0
        total_iter = len(train_loader)
        for batch in train_loader:
            # 正向传播
            optim.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs[0]
            total_train_loss += loss.item()
            
            # 反向梯度信息
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            # 参数更新
            optim.step()
            scheduler.step()
    
            iter_num += 1
            if(iter_num % 100==0):
                print("epoth: %d, iter_num: %d, loss: %.4f, %.2f%%" % (epoch, iter_num, loss.item(), iter_num/total_iter*100))
            
        print("Epoch: %d, Average training loss: %.4f"%(epoch, total_train_loss/len(train_loader)))
        
    def validation():
        model.eval()
        total_eval_accuracy = 0
        total_eval_loss = 0
        for batch in test_dataloader:
            with torch.no_grad():
                # 正常传播
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            
            loss = outputs[0]
            logits = outputs[1]
    
            total_eval_loss += loss.item()
            logits = logits.detach().cpu().numpy()
            label_ids = labels.to('cpu').numpy()
            total_eval_accuracy += flat_accuracy(logits, label_ids)
            
        avg_val_accuracy = total_eval_accuracy / len(test_dataloader)
        print("Accuracy: %.4f" % (avg_val_accuracy))
        print("Average testing loss: %.4f"%(total_eval_loss/len(test_dataloader)))
        print("-------------------------------")
        
    
    for epoch in range(4):
        print("------------Epoch: %d ----------------" % epoch)
        train()
        validation()
    

    训练一个Epoch的输出精度已经达到87%,Bert模型非常有效。

    ------------Epoch: 0 ----------------
    epoth: 0, iter_num: 2500, loss: 0.7519, 100.00%
    Epoch: 0, Average training loss: 0.6181
    Accuracy: 0.8747
    Average testing loss: 0.4602
    -------------------------------

    转自:https://zhuanlan.zhihu.com/p/388009679
  • 相关阅读:
    忠告20岁的年轻人
    mac电脑好用的工具总结
    idea 配置
    mac 安装mysql5.7.28附安装包
    国内外优秀网站收集
    MySql 数据库、数据表操作
    Java 高效代码50例
    Mac 修改版本号
    sql 语句系列(删库跑路系列)[八百章之第七章]
    sql 语句系列(更新系列)[八百章之第六章]
  • 原文地址:https://www.cnblogs.com/zhangxianrong/p/15067067.html
Copyright © 2011-2022 走看看