zoukankan      html  css  js  c++  java
  • 深度学习的模型是怎么训练/优化出来的

    以典型的分类问题为例,来梳理模型的训练过程。训练的过程就是问题发现的过程,一次训练是为下一步迭代做好指引。

    1.数据准备

    准备:

    • 数据标注前的标签体系设定要合理
    • 用于标注的数据集需要无偏、全面、尽可能均衡
    • 标注过程要审核

    整理数据集

    1. 将各个标签的数据放于不同的文件夹中,并统计各个标签的数目
      如:第一列是路径,最后一列是图片数目。

      PS:可能会存在某些标签样本很少/多,记下来模型效果不好就怨它。

    2. 样本均衡,样本不会绝对均衡,差不多就行了
      如:控制最大类/最小类<(delta)(delta=5),最后一列为均衡的目标值。

    3. 切分样本集
      如:90%用于训练,10%留着测试,比例自己定。训练集合,对于弱势类要重采样,最后的图片列表要shuffle;测试集合就不用重采样了。
      训练中要保证样本均衡,学习到弱势类的特征,测试过程要反应真实的数据集分布。
      第一列是图片路径,后面几列是标签(多任务)。

    4. 按需要的格式生成tfrecord
      按照train.list和validation.list生成需要的格式。生成和解析tfrecord的代码要根据具体情况编写。

    2.训练

    • 预处理,根据自己的喜好,编写预处理策略。
      preprocessing的方法,变换方案诸如:随机裁剪、随机变换框、添加光照饱和度、修改压缩系数、各种缩放方案、多尺度等。进而,减均值除方差或归一化到[-1,1],将float类型的Tensor送入网络。
      这一步的目的是:让网络接受的训练样本尽可能多样,不要最后出现原图没问题,改改分辨率或宽高比就跪了的情况。
    • 网络设计,基础网络的选择和Loss的设计。
      基础网络的选择和问题的复杂程度息息相关,用ResNet18可以解决的没必要用101;还有一些SE、GN等模块加上去有没有提升也可以去尝试。
      Loss的设计,一般问题的抽象就是设计Loss数据公式的过程。比如多任务中的各个任务权重配比,centorLoss可以让特征分布更紧凑,SmoothL1Loss更平滑避免梯度爆炸等。
    • 优化算法
      一般来说,只要时间足够,Adam和SGD+Momentum可以达到的效果差异不大。用框架提供的理论上最好的优化策略就是了。
    • 训练过程
      finetune网络,我习惯分两步:首先训练fc层,迭代几个epoch后保存模型;然后基于得到的模型,训练整个网络,一般迭代40-60个epoch可以得到稳定的结果。

      total_loss会一直下降的,过程中可以评测下模型在测试集上的表现。真正的loss往往包括两部分。后面total_loss的下降主要是正则项的功劳了。

    3.评估模型

    1.混淆矩阵必不可少
    混淆矩阵可以发现哪些类是难区分的。基于混淆矩阵可以得到各类的准召,进而可以得到哪些类比较差。
    如:列为真值,行为检测的值。

    gt/pl 靴子 单鞋 运动 休闲 棉鞋 雪地靴 帆布 拖鞋 凉鞋 雨鞋
    靴子 4524 45 39 79 12 59 5 6 0 20
    单鞋 51 4088 15 44 115 9 18 80 43 6
    运动 38 6 817 247 0 2 18 8 1 0
    休闲 53 47 171 806 17 8 118 15 1 2
    棉鞋 12 110 5 15 424 55 2 32 1 1
    雪地靴 53 6 5 10 73 628 0 13 2 1
    帆布鞋 5 28 16 158 1 1 515 17 3 4
    拖鞋 6 139 1 12 33 3 18 2316 60 6
    凉鞋 7 69 3 6 0 0 2 55 633 1
    雨鞋 26 6 1 3 0 1 2 5 1 499

    进而可得:

    label 召回 精度
    靴子 0.9446648569638756 0.947434554973822
    单鞋 0.9147460281942269 0.8996478873239436
    运动 0.7185576077396658 0.7614165890027959
    休闲 0.6510500807754442 0.5840579710144927
    ... ... ...

    PS:运动-休闲容易混淆。

    2.抽样看测试数据
    从测试数据中每类抽1000张,把它们的模型结果放在不同的文件夹下。对于分析问题还是很有效的,为什么它会分错,要拿出来看看!
    有些确实是人工标错了。

    3.CAM
    通过CAM可以查看网络究竟学到了什么(是不是学错了)。对于细粒度问题就不用分析CAM了,一般7x7的特征图本来就很小了,根本就看不出细节学到了什么,只能粗略看看部位定位是否准确。

    也可以一定程度上帮助理解为什么网络会搞错,比如下面的单鞋被误判为了拖鞋。

  • 相关阅读:
    批量更新sql |批量update sql
    智力测试题3
    【管理心得之二十一】管得少就是管得好
    查看sqlserver被锁的表以及如何解锁
    AD域相关的属性和C#操作AD域
    毕业5年小结一下
    WPF版公司的自动签到程序
    用友畅捷通高级前端笔试题(一)凭借回忆写出
    .NET中制做对象的副本(三)通过序列化和反序列化为复杂对象制作副本
    .NET中制做对象的副本(二)继承对象之间的数据拷贝
  • 原文地址:https://www.cnblogs.com/houkai/p/10221709.html
Copyright © 2011-2022 走看看