  • A PyTorch training framework with log management and visualization

    torchfurnace

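    torchfurnace is a toolkit that bundles fast model training, log management, model checkpoint management, TensorBoard visualization, I/O acceleration, and model size statistics.
    With it you can set up deep learning training quickly without writing the usual training logic yourself, and models you have already defined need no changes,
    so it is essentially ready to use out of the box.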
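Model size statistics come down to counting parameters and multiplying by the bytes each one occupies. The sketch below is a plain-Python hand count, not torchfurnace's own statistics routine (its API is not shown in this post). It assumes float32 weights and uses the CIFAR-10 VGG16 variant from the example further down: torchvision's VGG16 convolutional layers followed by `nn.Linear(512, 10)`. The helper names are hypothetical.

```python
# Hand count of trainable parameters for a CIFAR-10 VGG16 variant:
# VGG16 convolutional stack + nn.Linear(512, 10). Illustrative only; this is
# not torchfurnace's statistics function.

# VGG16 configuration: numbers are conv output channels, "M" is 2x2 max pooling.
VGG16_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
             512, 512, 512, "M", 512, 512, 512, "M"]


def conv3x3_params(in_ch: int, out_ch: int) -> int:
    """3x3 kernel weights (out_ch * in_ch * 9) plus one bias per output channel."""
    return out_ch * in_ch * 3 * 3 + out_ch


def linear_params(in_features: int, out_features: int) -> int:
    """Weight matrix plus bias vector."""
    return in_features * out_features + out_features


def vgg16_cifar_params(num_classes: int = 10) -> int:
    total, in_ch = 0, 3  # RGB input
    for v in VGG16_CFG:
        if v == "M":
            continue  # pooling layers have no parameters
        total += conv3x3_params(in_ch, v)
        in_ch = v
    return total + linear_params(512, num_classes)


n = vgg16_cifar_params()
size_mib = n * 4 / 1024 ** 2  # assumption: float32 weights, 4 bytes each
print(f"{n:,} parameters, about {size_mib:.2f} MiB")  # -> 14,719,818 parameters, about 56.15 MiB
```

In PyTorch itself, the same figure can be cross-checked on a model object with `sum(p.numel() for p in net.parameters())`.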

    Installation: pip install torchfurnace

    GitHub: https://github.com/tianyu-su/torchfurnace

    The following example quickly sets up a training run that trains VGG16 on CIFAR10.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torchvision.models as models
    import torchvision.transforms as transforms
    from torchvision.datasets import CIFAR10
    from torch.optim.lr_scheduler import MultiStepLR
    from torchfurnace import Engine, Parser
    from torchfurnace.utils.function import accuracy
    
    # define training process of your model
    class VGGNetEngine(Engine):
        @staticmethod
        def _on_forward(training, model, inp, target, optimizer=None) -> dict:
            # metrics reported back to the engine; the placeholders are filled in below
            ret = {'loss': object, 'acc1': object, 'acc5': object}
            output = model(inp)
            loss = F.cross_entropy(output, target)
    
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
    
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            ret['loss'] = loss.item()
            ret['acc1'] = acc1.item()
            ret['acc5'] = acc5.item()
            return ret
    
        @staticmethod
        def _get_lr_scheduler(optim) -> list:
            return [MultiStepLR(optim, milestones=[150, 250, 350], gamma=0.1)]
    
    def main():
        # define experiment name
        parser = Parser('TVGG16')
        args = parser.parse_args()
        experiment_name = '_'.join([args.dataset, args.exp_suffix])
    
        # Data
        ts = transforms.Compose([transforms.ToTensor(), transforms.Normalize(
            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
        trainset = CIFAR10(root='data', train=True, download=True, transform=ts)
        testset = CIFAR10(root='data', train=False, download=True, transform=ts)
    
        # define model and optimizer
        net = models.vgg16(pretrained=False, num_classes=10)
        net.avgpool = nn.AvgPool2d(kernel_size=1, stride=1)
        net.classifier = nn.Linear(512, 10)
        optimizer = torch.optim.Adam(net.parameters())
    
        # new engine instance
        eng = VGGNetEngine(parser, experiment_name)
        acc1 = eng.learning(net, optimizer, trainset, testset)
        print('Acc1:', acc1)
    
    if __name__ == '__main__':
        import sys
        # default flags for this demo, appended to sys.argv so that parse_args() in main() sees them
        run_params = '--dataset CIFAR10 -lr 0.1 -bs 128 -j 2 --epochs 400 --adjust_lr'
        sys.argv.extend(run_params.split())
        main()
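`_on_forward` reports top-1 and top-5 accuracy through `accuracy(output, target, topk=(1, 5))`. Under the conventional definition, a sample counts as correct for top-k if its true label is among the k highest scores. Below is a minimal pure-Python sketch of that definition. `topk_accuracy` is a hypothetical name. torchfurnace's `accuracy` operates on tensors (the example calls `.item()` on its results), and returning percentages is an assumption borrowed from the common PyTorch ImageNet example.

```python
from typing import List, Sequence, Tuple


def topk_accuracy(scores: Sequence[Sequence[float]], targets: Sequence[int],
                  topk: Tuple[int, ...] = (1, 5)) -> List[float]:
    """Percentage of samples whose true class is among the k highest-scoring classes."""
    hits = [0] * len(topk)
    for row, label in zip(scores, targets):
        # Class indices ordered from highest to lowest score.
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        for i, k in enumerate(topk):
            if label in ranked[:k]:
                hits[i] += 1
    return [100.0 * h / len(targets) for h in hits]


# Two samples, six classes (made-up scores for illustration).
scores = [[0.1, 0.7, 0.2, 0.0, 0.0, 0.0],
          [0.5, 0.1, 0.3, 0.05, 0.04, 0.01]]
targets = [1, 2]
print(topk_accuracy(scores, targets))  # -> [50.0, 100.0]
```

The second sample's true class 2 is only the second-highest score, so it misses top-1 but counts for top-5.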
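`_get_lr_scheduler` returns a `MultiStepLR` that multiplies the learning rate by `gamma=0.1` at epochs 150, 250 and 350. The schedule has a simple closed form, `base_lr * gamma ** bisect_right(milestones, epoch)`, which the sketch below evaluates. It assumes the scheduler is stepped once per epoch and a base learning rate of 0.1 taken from the `-lr 0.1` flag. The example builds `Adam` without an explicit `lr`, so whether torchfurnace forwards `-lr` to the optimizer is not shown. `multistep_lr` is a hypothetical helper.

```python
from bisect import bisect_right


def multistep_lr(base_lr: float, epoch: int, milestones=(150, 250, 350),
                 gamma: float = 0.1) -> float:
    """LR after `epoch` scheduler steps: decayed by `gamma` once per milestone reached."""
    return base_lr * gamma ** bisect_right(milestones, epoch)


# Assumed base LR of 0.1 (from the -lr flag); print the rate around each milestone.
for epoch in (0, 149, 150, 249, 250, 349, 350, 399):
    print(epoch, multistep_lr(0.1, epoch))
```

Under these assumptions the rate is 0.1 for epochs 0 to 149, 0.01 from epoch 150, 0.001 from epoch 250 and 0.0001 from epoch 350 until the last of the 400 epochs.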
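The transform `ts` applies two steps. `ToTensor` scales 8-bit pixel values to [0, 1]. `Normalize` then computes `(x - mean) / std` separately for each channel, using the CIFAR-10 statistics given in the code. For a single RGB pixel the arithmetic looks like this (plain Python; `normalize_pixel` is a hypothetical helper):

```python
# Per-channel mean and std copied from the example's transforms.Normalize call.
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)


def normalize_pixel(rgb):
    """Scale one (R, G, B) pixel of 0-255 integers to [0, 1], then standardize each channel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))


print(normalize_pixel((0, 0, 0)))        # black pixel
print(normalize_pixel((255, 255, 255)))  # white pixel
```

A black pixel therefore maps to roughly -2.43 in the red channel and a white one to roughly 2.51.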
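The model changes in `main()` follow from the feature-map size. torchvision's VGG16 uses padded 3x3 convolutions, which keep the spatial size, and five 2x2 max-pooling layers with stride 2, each of which halves it. A 32x32 CIFAR-10 image therefore leaves the convolutional stack as a 512x1x1 map. The stock model then applies `AdaptiveAvgPool2d((7, 7))` and a classifier expecting 512 * 7 * 7 = 25088 inputs. The example instead uses `AvgPool2d(kernel_size=1, stride=1)`, which leaves the 1x1 map unchanged, and a single `nn.Linear(512, 10)`. The arithmetic as a small sketch (`vgg_feature_shape` is a hypothetical helper):

```python
def vgg_feature_shape(height: int, width: int, channels: int = 512, num_pools: int = 5):
    """Output shape of VGG16's conv stack: padded 3x3 convs keep H and W,
    and each of the five 2x2 stride-2 max pools halves them."""
    for _ in range(num_pools):
        height, width = height // 2, width // 2
    return channels, height, width


c, h, w = vgg_feature_shape(32, 32)  # CIFAR-10 images are 32x32
flat = c * h * w                     # length of the vector after torch.flatten(x, 1)
print((c, h, w), flat)               # -> (512, 1, 1) 512
print(vgg_feature_shape(224, 224))   # -> (512, 7, 7), the ImageNet-sized case
```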
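The `__main__` block injects default flags by appending them to `sys.argv`, so `parser.parse_args()` sees them as if they had been typed on the command line. The same mechanism can be reproduced with the standard `argparse` module. The parser below is a hypothetical stand-in whose flag definitions are guessed from `run_params`. They are not copied from torchfurnace's `Parser`, whose real types and defaults are not shown here.

```python
import argparse
import sys


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical stand-in for torchfurnace's Parser; flags mirror run_params only."""
    p = argparse.ArgumentParser()
    p.add_argument("--dataset")
    p.add_argument("-lr", type=float)
    p.add_argument("-bs", type=int)
    p.add_argument("-j", type=int)
    p.add_argument("--epochs", type=int)
    p.add_argument("--adjust_lr", action="store_true")
    return p


run_params = "--dataset CIFAR10 -lr 0.1 -bs 128 -j 2 --epochs 400 --adjust_lr"
sys.argv = [sys.argv[0]] + run_params.split()  # simulate a command line
args = build_parser().parse_args()             # with no arguments, parse_args() reads sys.argv[1:]
print(args)
```

The sketch replaces `sys.argv` instead of extending it as the original does. This keeps the result independent of how the script is launched.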
    
  • Original article: https://www.cnblogs.com/TianyuSu/p/12560193.html