zoukankan      html  css  js  c++  java
  • PyTorch | 项目结构解析

    在学习和使用深度学习框架时,复现现有项目代码是必经之路,也能加深对理论知识的理解,提高动手能力。本文参照相关博客整理项目常用组织方式,以及每部分功能,帮助更好的理解复现项目流程,文末提供分类示例项目。

    1 项目组织

    在做深度学习实验或项目时,为了得到最优的模型结果,中间往往需要很多次的尝试和修改。一般项目都包含以下几个部分:

    • 模型定义
    • 数据处理和加载
    • 训练模型(Train&Validate)
    • 训练过程的可视化
    • 测试(Test/Inference)

    另外程序在组织过程中还应该满足以下几个要求:

    • 模型需具有高度可配置性,便于修改参数、修改模型,反复实验
    • 代码应具有良好的组织结构,使人一目了然
    • 代码应具有良好的说明,使其他人能够理解

    2 项目结构

    - checkpoints/: 用于保存训练好的模型,可使程序在异常退出后仍能重新载入模型,恢复训练
    - data/:数据相关操作,包括数据预处理、dataset实现等
    - models/:模型定义,可以有多个模型,例如上面的AlexNet和ResNet34,一个模型对应一个文件
    - utils/:可能用到的工具函数,在本次实验中主要是封装了可视化工具
    - config.py:配置文件,所有可配置的变量都集中在此,并提供默认值
    - main.py:主文件,训练和测试程序的入口,可通过不同的命令来指定不同的操作和参数
    - requirements.txt:程序依赖的第三方库
    - README.md:提供程序的必要说明

    3 解析

    3.1 __init__
    - __init__ 可以为空,也可以定义包的属性和方法,但必须存在,其他程序才能从这个目录中读取模块和函数

    3.2 数据加载
    使用Dataset提供数据集的封装,再使用Dataloader实现数据并行加载。

    - def __init__(self..)
    获取图片地址,并根据训练、验证和测试划分数据
    - def __getitem__(self, index):
    返回图片的数据和label
    - def __len__(self):
    返回数据集数量

    train_dataset = DogCat(opt.train_data_root, train=True)
    trainloader = DataLoader(train_dataset,
    batch_size = opt.batch_size,
    shuffle = True,
    num_workers = opt.num_workers)
    
    for ii, (data, label) in enumerate(trainloader):
    train()

    3.3 模型定义
    型的定义主要保存在models/目录下,其中BasicModule是对nn.Module的简易封装,提供快速加载和保存模型的接口。
    nn.Module主要包括save和load两个方法

    from models import AlexNet

    关于模型定义:
    - 尽量使用nn.Sequential(比如AlexNet)
    - 将经常使用的结构封装成子Module(比如GoogLeNet的Inception结构,ResNet的Residual Block结构)
    - 将重复且有规律性的结构,用函数生成(比如VGG的多种变体,ResNet多种变体都是由多个重复卷积层组成)

    3.4 工具函数
    可能会用到一些helper方法,这些方法可以统一放在utils/文件夹下,需要使用时再引入。在本例中主要是封装了可视化工具visdom的一些操作,

    3.5 配置文件
    可配置的参数主要包括:

    数据集参数(文件路径、batch_size等)
    训练参数(学习率、训练epoch等)
    模型参数

    在实际使用时,并不需要每次都修改config.py,只需要通过命令行传入所需参数,覆盖默认配置即可。

    3.6 main函数
    提到了fire
    main中包括train、val、test、help等

    训练的主要步骤如下:

    • 定义网络
    • 定义数据
    • 定义损失函数和优化器
    • 计算重要指标
    • 开始训练
    • 训练网络
    • 可视化各种指标
    • 计算在验证集上的指标

    4 示例分类代码

    #coding:utf8
    from config import opt
    import os
    import torch as t
    import models
    from data.dataset import DogCat
    from torch.utils.data import DataLoader
    from torch.autograd import Variable
    from torchnet import meter
    from utils.visualize import Visualizer
    from tqdm import tqdm
    
    def test(**kwargs):
        opt.parse(kwargs)
        import ipdb;
        ipdb.set_trace()
        # configure model
        model = getattr(models, opt.model)().eval()
        if opt.load_model_path:
            model.load(opt.load_model_path)
        if opt.use_gpu: model.cuda()
    
        # data
        train_data = DogCat(opt.test_data_root,test=True)
        test_dataloader = DataLoader(train_data,batch_size=opt.batch_size,shuffle=False,num_workers=opt.num_workers)
        results = []
        for ii,(data,path) in enumerate(test_dataloader):
            input = t.autograd.Variable(data,volatile = True)
            if opt.use_gpu: input = input.cuda()
            score = model(input)
            probability = t.nn.functional.softmax(score)[:,0].data.tolist()
            # label = score.max(dim = 1)[1].data.tolist()
            
            batch_results = [(path_,probability_) for path_,probability_ in zip(path,probability) ]
    
            results += batch_results
        write_csv(results,opt.result_file)
    
        return results
    
    def write_csv(results,file_name):
        import csv
        with open(file_name,'w') as f:
            writer = csv.writer(f)
            writer.writerow(['id','label'])
            writer.writerows(results)
        
    def train(**kwargs):
        opt.parse(kwargs)
        vis = Visualizer(opt.env)
    
        # step1: configure model
        model = getattr(models, opt.model)()
        if opt.load_model_path:
            model.load(opt.load_model_path)
        if opt.use_gpu: model.cuda()
    
        # step2: data
        train_data = DogCat(opt.train_data_root,train=True)
        val_data = DogCat(opt.train_data_root,train=False)
        train_dataloader = DataLoader(train_data,opt.batch_size,
                            shuffle=True,num_workers=opt.num_workers)
        val_dataloader = DataLoader(val_data,opt.batch_size,
                            shuffle=False,num_workers=opt.num_workers)
        
        # step3: criterion and optimizer
        criterion = t.nn.CrossEntropyLoss()
        lr = opt.lr
        optimizer = t.optim.Adam(model.parameters(),lr = lr,weight_decay = opt.weight_decay)
            
        # step4: meters
        loss_meter = meter.AverageValueMeter()
        confusion_matrix = meter.ConfusionMeter(2)
        previous_loss = 1e100
    
        # train
        for epoch in range(opt.max_epoch):
            
            loss_meter.reset()
            confusion_matrix.reset()
    
            for ii,(data,label) in tqdm(enumerate(train_dataloader),total=len(train_data)):
    
                # train model 
                input = Variable(data)
                target = Variable(label)
                if opt.use_gpu:
                    input = input.cuda()
                    target = target.cuda()
    
                optimizer.zero_grad()
                score = model(input)
                loss = criterion(score,target)
                loss.backward()
                optimizer.step()
                
                
                # meters update and visualize
                loss_meter.add(loss.data[0])
                confusion_matrix.add(score.data, target.data)
    
                if ii%opt.print_freq==opt.print_freq-1:
                    vis.plot('loss', loss_meter.value()[0])
                    
                    # 进入debug模式
                    if os.path.exists(opt.debug_file):
                        import ipdb;
                        ipdb.set_trace()
    
    
            model.save()
    
            # validate and visualize
            val_cm,val_accuracy = val(model,val_dataloader)
    
            vis.plot('val_accuracy',val_accuracy)
            vis.log("epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}".format(
                        epoch = epoch,loss = loss_meter.value()[0],val_cm = str(val_cm.value()),train_cm=str(confusion_matrix.value()),lr=lr))
            
            # update learning rate
            if loss_meter.value()[0] > previous_loss:          
                lr = lr * opt.lr_decay
                # 第二种降低学习率的方法:不会有moment等信息的丢失
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
            
    
            previous_loss = loss_meter.value()[0]
    
    def val(model,dataloader):
        '''
        计算模型在验证集上的准确率等信息
        '''
        model.eval()
        confusion_matrix = meter.ConfusionMeter(2)
        for ii, data in enumerate(dataloader):
            input, label = data
            val_input = Variable(input, volatile=True)
            val_label = Variable(label.type(t.LongTensor), volatile=True)
            if opt.use_gpu:
                val_input = val_input.cuda()
                val_label = val_label.cuda()
            score = model(val_input)
            confusion_matrix.add(score.data.squeeze(), label.type(t.LongTensor))
    
        model.train()
        cm_value = confusion_matrix.value()
        accuracy = 100. * (cm_value[0][0] + cm_value[1][1]) / (cm_value.sum())
        return confusion_matrix, accuracy
    
    def help():
        '''
        打印帮助的信息: python file.py help
        '''
        
        print('''
        usage : python file.py <function> [--args=value]
        <function> := train | test | help
        example: 
                python {0} train --env='env0701' --lr=0.01
                python {0} test --dataset='path/to/dataset/root/'
                python {0} help
        avaiable args:'''.format(__file__))
    
        from inspect import getsource
        source = (getsource(opt.__class__))
        print(source)
    
    if __name__=='__main__':
        import fire
        fire.Fire()
    View Code

    参考:https://github.com/chenyuntc/pytorch-best-practice/blob/master/PyTorch%E5%AE%9E%E6%88%98%E6%8C%87%E5%8D%97.md 

  • 相关阅读:
    教你几招,快速创建 MySQL 五百万级数据,愉快的学习各种优化技巧
    JConsole、VisualVM 依赖的 JMX 技术到底是什么
    面试官你好,我已经掌握了MySQL主从配置和读写分离,你看我还有机会吗?
    Redis 内存压缩原理
    No qualifying bean of type 'org.springframework.transaction.TransactionManager' available: more than one 'primary' bean found among candidates:
    Java 代理模式
    JVisualVM 使用 jmx 连接远程tomcat进行数据分析
    MySQL 同时 delete 多张表的数据
    idea 打包 springboot 项目,tomcat正常启动但访问报404
    算法面试题:一个List<Student>,要求删除里面的男生,不用Linq和Lamda,求各种解,并说明优缺点!
  • 原文地址:https://www.cnblogs.com/geo-will/p/11311418.html
Copyright © 2011-2022 走看看