  • Text Classification (5): Hands-on BERT with the transformers library, based on BertForSequenceClassification

    1. Code

    The script below fine-tunes bert-base-chinese for news classification: it reads tab-separated text/label files, encodes them with BertTokenizer, wraps the encodings in a Dataset, and trains BertForSequenceClassification with AdamW, a linear warmup schedule, and early stopping on validation accuracy.

    import os
    import json
    import numpy as np
    import torch
    import torch.nn as nn
    from torch.utils.data import Dataset, DataLoader
    from transformers import BertTokenizer, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup
    from config.root_path import root
    from utils.data_process import get_label, text_preprocess
    
    
    class NewsDataset(Dataset):
        def __init__(self, encodings, labels):
            self.encodings = encodings
            self.labels = labels
    
        # fetch a single sample as a dict of tensors
        def __getitem__(self, idx):
            item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
            item['labels'] = torch.tensor(int(self.labels[idx]))
            return item
    
        def __len__(self):
            return len(self.labels)
    
    # accuracy of a batch of logits against integer labels
    def flat_accuracy(preds, labels):
        pred_flat = np.argmax(preds, axis=1).flatten()
        labels_flat = labels.flatten()
        return np.sum(pred_flat == labels_flat) / len(labels_flat)
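    # e.g. flat_accuracy(np.array([[0.1, 0.9], [0.8, 0.2]]), np.array([1, 0])) -> 1.0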
    
    class EarlyStopper(object):
    
        def __init__(self, num_trials, save_path):
            self.num_trials = num_trials
            self.trial_counter = 0
            self.best_accuracy = 0
            self.save_path = save_path
    
        def is_continuable(self, model, accuracy):
            if accuracy > self.best_accuracy:
                self.best_accuracy = accuracy
                self.trial_counter = 0
                print("保存模型,指标:{}", accuracy)
                torch.save(model.state_dict(), self.save_path)
                return True
            elif self.trial_counter + 1 < self.num_trials:
                self.trial_counter += 1
                return True
            else:
                return False
    
    class run_bert():
    
        def __init__(self):
    
            data_path = os.path.join(root, "data")
            self.train_path = os.path.join(data_path, "train.txt")
            self.val_path = os.path.join(data_path, "val.txt")
            self.test_path = os.path.join(data_path, "test.txt")
            code_label_path = os.path.join(root, "code_to_label.json")
            if not os.path.exists(code_label_path):
                get_label()
            with open(code_label_path, "r", encoding="utf8") as f:
                self.code_label = json.load(f)
            self.model_name = os.path.join(root, "chkpt", "bert-base-chinese")
            self.num_label = len(self.code_label)
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.batch_size = 16
            # load the tokenizer once rather than on every encoding call
            self.tokenizer = BertTokenizer.from_pretrained(self.model_name)
    
        def read_file(self, path):
            sentences = list()
            labels = list()
            with open(path, "r", encoding="utf8") as f:
                for fr in f:
                    line = fr.strip().split("\t")
                    sentences.append(text_preprocess(line[0]))
                    labels.append(self.code_label[line[1]][2])
            return sentences, labels
    
        def get_datas(self):
            train_s, train_l = self.read_file(self.train_path)
            val_s, val_l = self.read_file(self.val_path)
            test_s, test_l = self.read_file(self.test_path)
            return train_s, train_l, val_s, val_l, test_s, test_l
    
        def s_encoding(self, s):
            # returns a dict with 'input_ids', 'attention_mask' and 'token_type_ids'
            return self.tokenizer(s, truncation=True, padding=True, max_length=40)
    
        # training loop for one epoch
        def train(self, model, train_loader, optim, device, scheduler, epoch, loss_fn):
            model.train()
            total_train_loss = 0
            iter_num = 0
            total_iter = len(train_loader)
            for batch in train_loader:
                # reset gradients, then run the forward pass
                optim.zero_grad()
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                # with labels supplied the model returns (loss, logits);
                # the loss is recomputed here with an explicit criterion
                logits = outputs[1]
                loss = loss_fn(logits, labels)
                total_train_loss += loss.item()
    
    
                # backward pass, with gradient clipping to stabilise training
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                # update parameters and advance the LR schedule
                optim.step()
                scheduler.step()
    
                iter_num += 1
                if (iter_num % 10 == 0):
                    print("epoth: %d, iter_num: %d, loss: %.4f, %.2f%%" % (
                    epoch, iter_num, loss.item(), iter_num / total_iter * 100))
    
            print("Epoch: %d, Average training loss: %.4f" % (epoch, total_train_loss / len(train_loader)))
    
        def validation(self, model, val_dataloader, device):
            model.eval()
            total_eval_accuracy = 0
            total_eval_loss = 0
            for batch in val_dataloader:
                with torch.no_grad():
                    # forward pass; no gradients needed during evaluation
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    labels = batch['labels'].to(device)
                    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    
                # with labels supplied the model returns (loss, logits);
                # accumulate loss and accuracy per batch, inside the loop
                loss = outputs[0]
                logits = outputs[1]
                total_eval_loss += loss.item()
                logits = logits.detach().cpu().numpy()
                label_ids = labels.to('cpu').numpy()
                total_eval_accuracy += flat_accuracy(logits, label_ids)
    
            avg_val_accuracy = total_eval_accuracy / len(val_dataloader)
            print("Accuracy: %.4f" % (avg_val_accuracy))
            print("Average testing loss: %.4f" % (total_eval_loss / len(val_dataloader)))
            print("-------------------------------")
            return avg_val_accuracy
    
        def main(self):
            train_s, train_l, val_s, val_l, test_s, test_l = self.get_datas()
            train_encoding = self.s_encoding(train_s)
            val_encoding = self.s_encoding(val_s)
    
            train_dataset = NewsDataset(train_encoding, train_l)
            val_dataset = NewsDataset(val_encoding, val_l)
    
            model = BertForSequenceClassification.from_pretrained(
                    self.model_name, num_labels=self.num_label)
            model.to(self.device)
            train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
            val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)  # no need to shuffle for evaluation
            optim = AdamW(model.parameters(), lr=2e-5)
            loss_fn = nn.CrossEntropyLoss()
            # schedule over the full training horizon rather than a single epoch,
            # otherwise the learning rate decays to zero after the first epoch
            max_epochs = 100
            total_steps = len(train_loader) * max_epochs
            scheduler = get_linear_schedule_with_warmup(optim,
                                                        num_warmup_steps=0,  # default value in run_glue.py
                                                        num_training_steps=total_steps)
            early_stopper = EarlyStopper(num_trials=5, save_path=os.path.join(root, "chkpt", "bert_classification.pt"))
            for epoch in range(max_epochs):
                print("------------Epoch: %d ----------------" % epoch)
                self.train(model, train_loader, optim, self.device, scheduler, epoch, loss_fn)
                acc = self.validation(model, val_dataloader, self.device)
                if not early_stopper.is_continuable(model, acc):
                    print(f'validation: best accuracy: {early_stopper.best_accuracy}')
                    break
    
            test_encoding = self.s_encoding(test_s)
            test_dataset = NewsDataset(test_encoding, test_l)
            test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)
            acc = self.validation(model, test_loader, self.device)
            print(f'test acc: {acc}')
    
    if __name__ == '__main__':
        run_bert().main()
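
    For reference, here is a minimal inference sketch for the saved checkpoint. It assumes the same `root`, model directory and `code_to_label.json` as the script above; `id_to_label`, which inverts the code-to-id mapping, is hypothetical and depends on the actual structure of the JSON file.

    import os
    import json
    import torch
    from transformers import BertTokenizer, BertForSequenceClassification
    from config.root_path import root

    model_name = os.path.join(root, "chkpt", "bert-base-chinese")
    tokenizer = BertTokenizer.from_pretrained(model_name)
    with open(os.path.join(root, "code_to_label.json"), "r", encoding="utf8") as f:
        code_label = json.load(f)

    # load the fine-tuned weights written by EarlyStopper
    model = BertForSequenceClassification.from_pretrained(model_name, num_labels=len(code_label))
    model.load_state_dict(torch.load(os.path.join(root, "chkpt", "bert_classification.pt"), map_location="cpu"))
    model.eval()

    # hypothetical inverse mapping: label id -> label code
    id_to_label = {v[2]: k for k, v in code_label.items()}

    def predict(sentence):
        encoding = tokenizer(sentence, truncation=True, padding=True, max_length=40, return_tensors="pt")
        with torch.no_grad():
            logits = model(**encoding).logits
        return id_to_label[int(logits.argmax(dim=-1))]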

    2. Classification results

    The model's accuracy is 82%, which is not great.
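
    A single accuracy number does not show which classes fail. A short diagnostic sketch (assuming scikit-learn is installed, and reusing the fine-tuned model, test_loader and device from the script above):

    import torch
    from sklearn.metrics import classification_report

    all_preds, all_labels = [], []
    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            logits = model(batch["input_ids"].to(device),
                           attention_mask=batch["attention_mask"].to(device)).logits
            all_preds.extend(logits.argmax(dim=-1).cpu().tolist())
            all_labels.extend(batch["labels"].tolist())

    # per-class precision/recall/F1 shows which categories drag accuracy down
    print(classification_report(all_labels, all_preds))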

     

  • Original post: https://www.cnblogs.com/zhangxianrong/p/15127174.html