zoukankan      html  css  js  c++  java
  • 【pytorch-huggingface/transformer】 工作流程整理(未完成)

    0、引言

    本文记录使用pytorch、huggingface/transformer 框架工作流程,内容包括:

    1. 数据读取
    2. 数据预处理(split shuffle)
    3. 预训练模型下载和准备(预训练模型参数下载,模型对应Token及超参初始化)
    4. 模型训练、验证、结果测试
    5. 模型本地持久化
    6. 训练过程数据可视化
    7. 部署生成环境上线使用

    1、数据读取

    自己定义数据集Class,通常该类有这么几个功能:

    1. 存X(数据或Emb)、 存y(label)
    2. 根据索引取对应数据
    3. 求数据集大小(求Length)

    如下简例,自定义数据集, 需继承torch.utils.data.Dataset

    class  XXXXDataset(torch.utils.data.Dataset):
        def  __init__(self, encodings, labels):
            '''数据集构造器,创建X 和 y'''
            self.encodings = encodings
            self.labels = labels
    
        def __getitem__(self, idx):
            '''通过idx得到对应元素,数据集item是一个Dict,key是索引'''
            item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
            item ['labels'] = torch.tensor(self.labels[idx])
            return item
    
        def __len__(self):
            return len(self.labels)
    

    2、数据预处理

    使用sklearn中对数据进行随机拆分、shuffle

    from sklearn.model_selection import train_test_split
    train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts, train_labels, test_size=.2)
    

    3、模型下载和准备(微调训练器)

    使用transformer库的模型和embedding(Tokenizer),注意NLP中同一个模型可应用于不同的下游任务中,即模型名称的后缀不同

    ## 下载预训练模型
    from transformers import DistilBertTokenizerFast
    tokenzier = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
    
    
    ##使用Trainer来训练模型
    from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments
    model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
    
    training_args = TrainingArguments(
        output_dir = './results', 
        num_train_epochs =3 ,   # 迭代次数
        per_device_train_batch_size = 16, 
        per_device_eval_batch_size = 64, 
        warmup_steps = 500, 
        weight_decay = 0.01,
        logging_dir = './logs',
        logging_steps = 10, 
    )
    ## 注意Trainer对象是传入训练集和验证集,个人理解在训练时可以对val_dataset进行验证测试
    trainer = Trainer(  
        model = model ,
        args = training_args,
        train_dataset = train_dataset,  
        eval_dataset = val_dataset
    )
    ## 开始微调训练器
    trainer.train()
    

    4、模型训练、验证、测试(待完善)

    使用DataLoader来进行batch迭代每一步训练(验证),测试时可以一次性测试。

    from torch.utils.data import DataLoader
    from transformers import DistilBertForSequenceClassification, AdamW
    ## cpu或GPU计算
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    
    ## 加载预训练模型
    model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
    model.to(device)
    model.train()  ## 微调训练模型(而非微调训练训练器)
    
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    
    optim = AdamW(model.parameters(), lr=5e-5) #定义优化器,AdamW表示待权重衰减
    
    for epoch in range(3):
        for batch in train_loader:
            optim.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs[0]
            loss.backward()
            optim.step()
    

    5、模型本地持久化

    import torch
    ###  保存模型结构和参数
    torch.save(model,'./model.pkl' )
    model2 = torch.load('./model.pkl')
    print(model2)
    
    ### 只保存模型参数 
    torch.save(model.state_dict(), './model_params.pkl')
    model_params = torch.load('./model_params.pkl')
    print(model_params)
    
    
  • 相关阅读:
    面试题
    网络编程-1
    excel文件导入数据库--jxl包
    excel文件导入数据库
    1113 Integer Set Partition (25 分)集合分割
    1120 Friend Numbers (20 分)set的使用
    1099 Build A Binary Search Tree (30 分)
    1092 To Buy or Not to Buy (字符串删除)
    1127 ZigZagging on a Tree (30 分)树的层次遍历
    1155 Heap Paths (30 分)判断是否是一个堆
  • 原文地址:https://www.cnblogs.com/andre-ma/p/15268597.html
Copyright © 2011-2022 走看看