zoukankan      html  css  js  c++  java
  • pytorch进行图像分类的流程,下一篇为实例源代码解析

    一、预处理部分

    1.拿到数据首先对数据进行分析

    对数据的分布有一个大致的了解,可以用画图函数查看所有类的分布情况。可以采取删除不合理类的方法来提高准确率;

    对图像进行分析,在自定义的图像增强的多种方式中,尝试对图像进行变换,看是否存在主观上的特征增强,具体的增强

    方法在aug.py文件中,可以在线下对数据进行测试,看是否在增强后对结果有好的影响。

    2.模型的选取

    依据新模型效果较好的原则,尽量选取已存在的最新模型,可以选取进几年再imagenet比赛上取得最好的效果的几种模型

    分别进行测试,目前效果最好的模型是resnet(深度残差网络),是卷积神经网络的最新发展;

    但仅仅单模型的效果肯定是不如多模型综合的效果好的,所以可以选取效果较好的几种模型,最后按其权重进行加权平均

    来获取最终的预测结果;

    始终要注意的一点是,模型是次要的,最主要最核心的问题还是在于对于数据的处理。

    3.处理数据

    对数据图像进行增强,不管是使用pytorch自带的transform模块,还是自定义的数据增强处理方式,都要对数据进行合理的

    改变,最基本的改变是对图像进行简单的随机翻转、切割、旋转等,还有要注意的一点是需要改变图像的尺寸,以适应模型

    的输入要求。

    本次比赛数据进行的增强方式有:

    • RandomRotation(30)
    • RandomHorizontalFlip()
    • RandomVerticalFlip()
    • RandomAffine(45)

    4.超参数的设置

    对于整体代码中所需要的超参数进行单独处理,设置在一个文件中,使用时候直接调用即可。

    二、输入数据进入模型进行训练

    1.划分数据集

    首先根据所给文件把每个类的图像都分类到各自的文件夹中去,模型的输入要求类型基本都是这样,然后对于数据集划分为

    训练集、测试集、验证集,分别在模型的训练、测试阶段使用。

    2.模型训练

    根据pytorch的模型训练过程,输入训练集,对模型进行训练,每个epoch后对模型进行评价,在整个epoch结束后,得到最好

    的模型。

    3.测试阶段

    把测试集输入保存的最好模型中去,得到输出结果,进行分析。

    三、pytorch中的训练模块化

    1.加载模型

    2.优化器和loss函数的设置

    3.训练集加载入pytorch的数据加载类Dataloader中,以便于调用

    4.开始每个epoch的训练,输入,目标,loss,归零,反向传播,开始

    5.评估模型,得出最优模型

    参考大神chaojiezhu的github。

    https://github.com/spytensor/plants_disease_detection

  • 相关阅读:
    SQL后台分页三种方案和分析
    SQL分页查询语句
    SQL利用临时表实现动态列、动态添加列
    查询sybase DB中占用空间最多的前20张表
    敏捷软件开发之TDD(一)
    敏捷软件开发之开篇
    Sql Server 2012启动存储过程
    改变VS2013的菜单栏字母为小写
    Sql Server获得每个表的行数
    Sql Server trace flags
  • 原文地址:https://www.cnblogs.com/ywheunji/p/10127951.html
Copyright © 2011-2022 走看看