zoukankan      html  css  js  c++  java
  • improved open set domain adaptation with backpropagation 学习笔记

    improved open set domain adaptation with backpropagation 学习笔记

    @

    TIP

    KL距离,是Kullback-Leibler差异(Kullback-Leibler Divergence)的简称,也叫做相对熵(Relative Entropy)。它衡量的是相同事件空间里的两个概率分布的差异情况
    本文更改了OpenBP中的二元交叉熵损失,提高了识别率。

    ABSTRACT

    本文对Open set domain adaptation by back propagation(OSDA-BP)中用于提取潜在未知类别样本的二元交叉熵损失进行了深入的研究。基于这种新的理解,我们提出用对称的库勒贝克-莱布勒距离损失来代替二元交叉熵损失

    作者透彻详尽地解释了对于OSDA-BP中二元交叉熵损失的理解,并使用对称的KL距离来提出一个新的二元交叉熵损失公式

    2.PROPOSED METHOD

    2.1 Overall Idea

    本文的方法框架主要还是基于论文《Open set domain adaptation by backpropagation》中的框架。该方法的框图为:

    image-20201108213044235

    其中源域(D_s={(x^s_i,y^s_i)}^{n_s}_{i=1})拥有(n_s)个已标注的样本,而目标域(D_t={(x^t_i,y^t_i)}^{n_t}_{i=1})拥有(n_t)个未标注的样本,其中x表示样本的图像,y表示样本相对应的标签。源域与目标域之间都存在彼此未拥有的类别。在这样的设定下,作者基于CNN训练网络(f( heta,x)),来将输入的样本(x_s)或者(x_t)分类成K+1类,其中K表示已知类的个数,第K+1类表示未知类。即(f( heta,x)={P(cls(x)=1...P(cls(x)=K+1))}).

    模型使用了一个特征提取器与一个分类器,其中(f( heta,x)=C(G( heta_g,x), heta_c))。(( heta_g)表示特征提取器的参数,而( heta_c)表示分类器的参数)

    在OpenBP中,首先使用标准交叉熵损失(L_s)来进行源域样本的分类:

    (L_s( heta,D_s)=frac{1}{|D_s|}sum limits_{(x_s,y_s)in D_s}l(y_s,f( heta,x_s))),其中的(l(y,f)=-sum limits_{j=1}limits^{K}y_jlog(f_j)),(|D_s|)表示源域样本的个数。

    接着OpenBP使用一个二元交叉熵损失(L_u)训练分类器来形成目标域中已知类与未知类之间的边界:(L_u( heta,x_t)=-(1-t)(1-log(P(cls(x_t)=K+1)))-tlog(P(cls(x_t)=K+1))),t的值为0.5.

    为了将目标域中未知类别的样本分离,我们还可以使用二元交叉熵损失的平均形式(L_u( heta,D_t)=frac{1}{|D_t|}sumlimits_{x_tin D_t}L_u( heta,x_t)).

    使用(p_t=(t,1-t))来表示一个由t(0<t<1)参数化的二元分布,对于任何目标域样本(x_t),令(hat{t} riangleq P(cls(x_t)=K+1)),且(p_{hat{t}}=(hat{t},1-hat{t}))。则(p_t)(p_{hat{t}})之间的KL距离为

    (d_{KL}(p_t||p_{hat{t}})=tlogfrac{t}{hat{t}}+(1-t)logfrac{1-t}{1-hat{t}}=-tloghat{t}-(1-t)log(1-hat{t})+v(t)).

    其中(v(t)=tlogt+(1-t)log(1-t))对于一个固定的t来说是一个固定的值。

    上面的二元交叉熵损失(L_u)可以看作为((t,1-t))((p(cls(x_t)=K+1),1-p(cls(x_t)=K+1))之间除去常数(v(t))的KL距离,通过设置t = 0.5,它为训练好的网络提供了一个合理的机制来区分已知类和未知类。

    由于二元交叉熵损失本质上是一个KL距离,我们可以进一步利用它的对称形式:

    (L_{adv}( heta,t,D_t)=frac{1}{|D_t|}sumlimits_{x_tin D_t}L_{adv}( heta,t,x_t))

    (L_{adv}( heta,t,x_t)=d_{KL}(p_t||p_{hat{t}(x_t)})+d_{KL}(p_{hat{t}(x_t)}||p_t)).

    整理之后,总的损失为:

    (L( heta,t)=L_s( heta,D_s)+lambda_1L_{adv}( heta,t,D_t)),(lambda_1)=0.5

    总的目标函数为:

    (minlimits_{ heta_c}L_s( heta,D_s)+lambda_1L_{adv}( heta,t,D_t)).

    (minlimits_{ heta_g}L_s( heta,D_s)-lambda_1L_{adv}( heta,t,D_t)).

  • 相关阅读:
    django操作mysql时django.db.utils.OperationalError: (2003, "Can't connect to MySQL server")异常的解决方法
    Django实践:个人博客系统(第七章 Model的设计和使用)
    shared_ptr / weak_ptr 代码片段
    Java中比较容易混淆的知识点
    指针和引用作为参数的区别
    STL 算法
    STL扩展容器
    STL中 map 和 multimap
    STL中 set 和 multiset
    <<C++标准程序库>>中的STL简单学习笔记
  • 原文地址:https://www.cnblogs.com/Jason66661010/p/13948609.html
Copyright © 2011-2022 走看看