zoukankan      html  css  js  c++  java
  • 一个简洁、好用的Pytorch训练模板

    一个简洁、好用的Pytorch训练模板

    代码地址:https://github.com/KinglittleQ/Pytorch-Template

    怎么使用

    1) 更改template.py

    替换 __init__方法中的内容,增添自己的模型、优化器、评估器等等.

    class Model():
    
        def __init__(self, args):
            self.writer = tX.SummaryWriter(log_dir=None, comment='')
            self.train_logger = None  # not neccessary
            self.eval_logger = None  # not neccessary
            self.args = args  # not neccessary
    
            self.step = 0
            self.epoch = 0
            self.best_error = float('Inf')
    
            self.model = None
            self.optimizer = None
            self.criterion = None
            self.metric = None
    
            self.train_loader = None
            self.test_loader = None
    
            self.device = None
    
            self.ckpt_dir = None
            self.log_per_step = None
    

    2) 写部分训练代码

    你所需要做的只是写一个简单的for循环:

    model = Model()
    
    for epoch in range(n_epochs):
        model.train()
        if (epoch + 1) % eval_per_epoch == 0:
            model.eval()
    
    print('Done!!!')
    

    3) 继续训练

    继续训练十分方便,只需要加载之前保存好的模型。

    model = Model()
    if model_path:
        model.load_state(model_path)
    
    for i in range(n_epochs):
        model.train()
        if model.epoch % eval_per_epoch == 0:
            model.eval()
    

    Example

    • LeNet: 训练一个LeNet对MNIST手写数字进行分类

      • 训练过程如下:

        ......
        epoch 1 step 3400   loss 0.0434
        epoch 1 step 3500   loss 0.0331
        epoch 1 step 3600   loss 0.00188
        epoch 1 step 3700   loss 0.00341
        save model at ../modelsest.pth.tar
        save model at ../models1.pth.tar
        epoch 1 error 0.0237
        epoch 2 step 3800   loss 0.0201
        epoch 2 step 3900   loss 0.00523
        epoch 2 step 4000   loss 0.0236
        ......
        
      • 使用tensorboard可视化输出:

        tensorboard --logdir example/LeNet/log
        


      • 继续训练

        load model from checkpoint/9.pth.tar
        epoch 10    step 33800  loss 0.000128
        epoch 10    step 33900  loss 6.64e-06
        epoch 10    step 34000  loss 0.000613
        epoch 10    step 34100  loss 2.41e-05
        ......
        
  • 相关阅读:
    使用httperrequest,模拟发送及接收Json请求
    VI/VIM 常用命令
    Robot Framework开发系统关键字详细
    Python logging模块使用记录
    反编译app方法
    python+appium使用记录
    查看apk包及Activity名方法
    Robot Framework使用技巧
    git 常用使用及问题记录
    多个git账户生成多份rsa秘钥实现多个账户同时使用配置
  • 原文地址:https://www.cnblogs.com/magic-girl/p/pytorch_template.html
Copyright © 2011-2022 走看看