zoukankan      html  css  js  c++  java
  • Siamese网络

    1.       对比损失函数(Contrastive Loss function)

    孪生架构的目的不是对输入图像进行分类,而是区分它们。因此,分类损失函数(如交叉熵)不是最合适的选择,这种架构更适合使用对比函数。对比损失函数如下:

     

    (以判断图片相似度为例)其中Dw被定义为姐妹孪生网络的输出之间的欧氏距离。Y值为1或0。如果模型预测输入是相似的,那么Y的值为0,否则Y为1。m是大于0的边际价值(margin value)。有一个边际价值表示超出该边际价值的不同对不会造成损失。

    Siamese网络架构需要一个输入对,以及标签(类似/不相似)。

    2.       孪生网络的训练过程

    (1)    通过网络传递图像对的第一张图像。

    (2)    通过网络传递图像对的第二张图像。

    (3)    使用(1)和(2)中的输出来计算损失。

     

      其中,l12为标签,用于表示x1的排名是否高于x2。

    训练过程中两个分支网络的输出为高级特征,可以视为quality score。在训练时,输入是两张图像,分别得到对应的分数,将分数的差异嵌入loss层,再进行反向传播。

    (4)    返回传播损失计算梯度。

    (5)    使用优化器更新权重。

    3.       基于Siamese网络的无参考图像质量评估:RankIQA

    3.1          参考文献

    https://arxiv.org/abs/1707.08347

    3.2          RankIQA的流程

    (1)    合成失真图像。

    (2)    训练Siamese网络,使网络输出对图像质量的排序。

    (3)    提取Siamese网络的一支,使用IQA数据集进行fine-tune。将网络的输出校正为IQA度量值。fine-tune阶段的损失函数如下:

     

    训练阶段使用Hinge loss,fine-tune阶段使用MSE。

    训练时,每次从图像中随机提取224*224或者227*227大小的图像块。和AlexNet、VGG16有关。在训练Siamese network时,初始的learning rate 是1e-4;fine-tuning是初始的learning rate是1e-6,每隔10k步,rate变成原来的0.1倍。训练50k次。测试时,随机在图像中提取30个图像块,得到30个分数之后,进行均值操作。

    本文如何提高Siamese网络的效率:

    假设有三张图片,则每张图片将被输入网络两次,原因是含有某张图片的排列数为2。为了减少计算量,每张图片只输入网络一次,在网络之后、损失函数之前新建一层,用于生成每个mini-batch中图片的可能排列组合。

    使用上述方法,每张图片只前向传播一次,只在loss计算时考虑所有的图片组合方式。

    本文使用的网络架构:Shallow, AlexNet, and VGG-16。

    4.       Siamese网络的开源实现

    4.1          代码地址

    https://github.com/xialeiliu/RankIQA

    4.2          RankIQA的运行过程

    4.2.1            数据集

    使用两方面的数据集,一般性的非IQA数据集用于生成排序好的图片,进而训练Siamese网络;IQA数据集用于微调和评估。

    本文使用的IQA数据集:

    (1)    LIVE数据集:http://live.ece.utexas.edu/research/quality/ 对29张原始图片进行五类失真处理,得到808张图片。Ground Truth MOS在[0, 100]之间(人工评分)。

    (2)    TID2013:25张原始图片,3000张失真图片。MOS范围是[0, 9]。

    本文使用的用于生成ranked pairs的数据集:

    (1)    为了测试LIVE数据集,人工生成了四类失真,GB(Gaussian Blur)、GN(Gaussian Noise)、JPEG、JPEG2K

    (2)    为了在TID2013上测试,生成了17种失真(去掉了#3, #4,#12, #13, #20, #21, #24)

    Waterloo数据集:

    包含4744张高质量自然图片。

    Places2数据集:

    作为验证集(包含356种场景,http://places2.csail.mit.edu/ ),每类100张,共35600张。

    两种数据集的区别:

    python generate_rank_txt_tid2013.py生成的是tid2013_train.txt,标签只起到表示相对顺序的作用,即,标签为{1, 2, 3, 4, 5};python generate_ft_txt_tid2013.py生成的是ft_tid2013_test.txt,其中的标签是浮点数,表示图片的质量评分。

    4.2.2            训练和测试过程

    从原始图像中随机采样子图(sub-images),避免因差值和过滤而产生的失真。输入的子图至少占原图的1/3,以保留场景信息。本文采用227*227或者224*224的采样图像(根据使用的主干网络而不同)。

    训练过程使用mini-batch SGD,初始学习率1e-4,fine-tune学习率1e-6。

    共迭代50K次,每10K次减小学习率(乘以0.1),两个训练过程都是用l2权重衰减(正则化系数lambda=5e-4)。

    实验一:本文首先使用Places2数据集(使用五种失真进行处理)训练网络(不进行微调),然后在Waterloo数据及上进行预测IQA(使用同样的五种失真进行处理)。实验结果如图2所示。

     

    实验二:hard negative mining

    难分样本挖掘,是当得到错误的检测patch时,会明确的从这个patch中创建一个负样本,并把这个负样本添加到训练集中去。重新训练分类器后,分类器会表现的更好,并且不会像之前那样产生多的错误的正样本。

    本实验使用Alexnet进行。

    实验三:网络性能分析

    LIVE数据集,80%训练集,评价指标LCC和SROCC。VGG-16的效果最好。

    4.2.3            RankIQA对数据集的处理过程

    将原始图像文件放在data/rank_tid2013/pristine_images路径下,然后运行data/rank_tid2013/路径下的tid2013_main.m,进而生成排序数据集(17种失真形式)。

    4.3          运行指令

    4.3.1            Train RankIQA

    To train the RankIQA models on tid2013 dataset:

    ./src/RankIQA/tid2013/train_vgg.sh

    To train the RankIQA models on LIVE dataset:

    ./src/RankIQA/live/train_vgg.sh

    FT

    To train the RankIQA+FT models on tid2013 dataset:

    ./src/FT/tid2013/train_vgg.sh

    To train the RankIQA+FT models on LIVE dataset:

    ./src/FT/live/train_live.sh

    4.3.2            Evaluation for RankIQA

    python src/eval/Rank_eval_each_tid2013.py  # evaluation for each distortions in tid2013

    python src/eval/Rank_eval_all_tid2013.py   # evaluation for all distortions in tid2013

    Evaluation for RankIQA+FT on tid2013:

    python src/eval/FT_eval_each_tid2013.py  # evaluation for each distortions in tid2013

    python src/eval/FT_eval_all_tid2013.py   # evaluation for all distortions in tid2013

    Evaluation for RankIQA on LIVE:

    python src/eval/Rank_eval_all_live.py   # evaluation for all distortions in LIVE

    Evaluation for RankIQA+FT on LIVE:

    python src/eval/FT_eval_all_live.py   # evaluation for all distortions in LIVE

    5.       代码调试过程

    5.1          Python无法导入某个模块ImportError:could not find module XXX

    解决方案:

    配置环境变量:export PYTHONPATH=path/to/modules

  • 相关阅读:
    我对自己公司产品的看法与一点微不足道的建议
    Error:java: 无效的源发行版: 1.8
    生成带星期的日期格式
    使用RestTemplate发送multipart/form-data格式的数据
    解决java.lang.NoClassDefFoundError错误
    Invalid bound statement (not found) 问题处理
    java8 关于日期的处理
    关于java后台如何接收xml格式的数据
    关于线程和junit注入失败的问题
    多线程异步调度任务
  • 原文地址:https://www.cnblogs.com/sddai/p/10613429.html
Copyright © 2011-2022 走看看