zoukankan      html  css  js  c++  java
  • 一个简洁、好用的Pytorch训练模板

    一个简洁、好用的Pytorch训练模板

    代码地址:https://github.com/KinglittleQ/Pytorch-Template

    怎么使用

    1) 更改template.py

    替换 __init__方法中的内容,增添自己的模型、优化器、评估器等等.

    class Model():
    
        def __init__(self, args):
            self.writer = tX.SummaryWriter(log_dir=None, comment='')
            self.train_logger = None  # not neccessary
            self.eval_logger = None  # not neccessary
            self.args = args  # not neccessary
    
            self.step = 0
            self.epoch = 0
            self.best_error = float('Inf')
    
            self.model = None
            self.optimizer = None
            self.criterion = None
            self.metric = None
    
            self.train_loader = None
            self.test_loader = None
    
            self.device = None
    
            self.ckpt_dir = None
            self.log_per_step = None
    

    2) 写部分训练代码

    你所需要做的只是写一个简单的for循环:

    model = Model()
    
    for epoch in range(n_epochs):
        model.train()
        if (epoch + 1) % eval_per_epoch == 0:
            model.eval()
    
    print('Done!!!')
    

    3) 继续训练

    继续训练十分方便,只需要加载之前保存好的模型。

    model = Model()
    if model_path:
        model.load_state(model_path)
    
    for i in range(n_epochs):
        model.train()
        if model.epoch % eval_per_epoch == 0:
            model.eval()
    

    Example

    • LeNet: 训练一个LeNet对MNIST手写数字进行分类

      • 训练过程如下:

        ......
        epoch 1 step 3400   loss 0.0434
        epoch 1 step 3500   loss 0.0331
        epoch 1 step 3600   loss 0.00188
        epoch 1 step 3700   loss 0.00341
        save model at ../modelsest.pth.tar
        save model at ../models1.pth.tar
        epoch 1 error 0.0237
        epoch 2 step 3800   loss 0.0201
        epoch 2 step 3900   loss 0.00523
        epoch 2 step 4000   loss 0.0236
        ......
        
      • 使用tensorboard可视化输出:

        tensorboard --logdir example/LeNet/log
        


      • 继续训练

        load model from checkpoint/9.pth.tar
        epoch 10    step 33800  loss 0.000128
        epoch 10    step 33900  loss 6.64e-06
        epoch 10    step 34000  loss 0.000613
        epoch 10    step 34100  loss 2.41e-05
        ......
        
  • 相关阅读:
    Java Web 网络留言板2 JDBC数据源 (连接池技术)
    Java Web 网络留言板3 CommonsDbUtils
    Java Web ConnectionPool (连接池技术)
    Java Web 网络留言板
    Java Web JDBC数据源
    Java Web CommonsUtils (数据库连接方法)
    Servlet 起源
    Hibernate EntityManager
    Hibernate Annotation (Hibernate 注解)
    wpf控件设计时支持(1)
  • 原文地址:https://www.cnblogs.com/magic-girl/p/pytorch_template.html
Copyright © 2011-2022 走看看