zoukankan      html  css  js  c++  java
  • 【猫狗数据集】从命令行接收参数

    数据集下载地址:

    链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw
    提取码:2xq4

    创建数据集:https://www.cnblogs.com/xiximayou/p/12398285.html

    读取数据集:https://www.cnblogs.com/xiximayou/p/12422827.html

    进行训练:https://www.cnblogs.com/xiximayou/p/12448300.html

    保存模型并继续进行训练:https://www.cnblogs.com/xiximayou/p/12452624.html

    加载保存的模型并测试:https://www.cnblogs.com/xiximayou/p/12459499.html

    划分验证集并边训练边验证:https://www.cnblogs.com/xiximayou/p/12464738.html

    使用学习率衰减策略并边训练边测试:https://www.cnblogs.com/xiximayou/p/12468010.html

    利用tensorboard可视化训练和测试过程:https://www.cnblogs.com/xiximayou/p/12482573.html

    epoch、batchsize、step之间的关系:https://www.cnblogs.com/xiximayou/p/12405485.html

    本节我们要在命令行接收参数,包括batch_size的值以及网络的类型。

    基本上我们只需要修改main.py就行了:

    main.py

    import sys
    sys.path.append("/content/drive/My Drive/colab notebooks")
    from utils import rdata
    from model import resnet
    import torch.nn as nn
    import torch
    import numpy as np
    import torchvision
    import train
    import torch.optim as optim
    
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    
    def main(batch_size,baseline):
      train_loader,val_loader,test_loader=rdata.load_dataset(batch_size)
      if baseline:
        model =torchvision.models.resnet18(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features,2,bias=False)
      if torch.cuda.is_available():
        model.cuda()
    
      #定义训练的epochs
      num_epochs=100
      #定义学习率
      learning_rate=0.1
      #定义损失函数
      criterion=nn.CrossEntropyLoss()
      #定义优化方法,简单起见,就是用带动量的随机梯度下降
      optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1, momentum=0.9,
                                weight_decay=1*1e-4)
      scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [40,80], 0.1)
      print("训练集有:",len(train_loader.dataset))
      #print("验证集有:",len(val_loader.dataset))
      print("测试集有:",len(test_loader.dataset))
      trainer=train.Trainer(criterion,optimizer,model)
      trainer.loop(num_epochs,train_loader,val_loader,test_loader,scheduler)
    
    if __name__ == "__main__":
      import argparse
      p=argparse.ArgumentParser()
      p.add_argument("--batch_size",type=int,default=64)
      p.add_argument("--baseline",action="store_true")
      args=p.parse_args()
      main(args.batch_size,args.baseline)

    说明:我们将读取数据集、定义损失、优化器等代码放入到main()函数中,然后给main传入batch_size和baseline。使用argparse可以从命令行接收参数。add_argument()函数中,第一个参数是参数的名称,第二个是参数的类型,default是默认值,即不在命令行输入--batch_size 具体值,则会使用默认值。需要关注的是action="store_true",该参数的意思是默认baseline为False,如果在命令行中加入了--baseline,则baseline的值就为True。

    结果如图所示:

    没有加--batch_size,则batch_size默认为64,也就是18255/64约等于286。然后我们使用了--baseline,即默认使用resnet18模型。

    由于图像分类一般考虑的衡量指标是top1和top5,下一节就是加上计算top5的代码了。

  • 相关阅读:
    Reinforcement Learning Qlearning 算法学习3 AI
    2012年末工作中遇到的问题总结及感悟
    对JMS的一些认识
    readme
    数据库表扩展字段设计思路
    对网络安全性和apache shiro的一些认识
    JDK版本的了解
    下拉框“数据字典”设计
    缓存学习总结
    apache commons包简介
  • 原文地址:https://www.cnblogs.com/xiximayou/p/12488662.html
Copyright © 2011-2022 走看看