zoukankan      html  css  js  c++  java
  • Feature Distillation With Guided Adversarial Contrastive Learning

    Bai T., Chen J., Zhao J., Wen B., Jiang X., Kot A. Feature Distillation With Guided Adversarial Contrastive Learning. arXiv preprint arXiv 2009.09922, 2020.

    本文是通过固定教师网络(具有鲁棒性), 让学生网络去学习教师网络的鲁棒特征. 相较于一般的distillation 方法, 本文新加了reweight机制, 另外其损失函数非一般的交叉熵, 而是最近流行的对比损失.

    主要内容

    在这里插入图片描述

    本文的思想是利用robust的教师网络(f^t)来辅助训练学生网络(f^s), 假设有输入((x, y)), 通过网络得到特征

    [t^+:= f^t(x), s^+:=f^s(x), ]

    ((t^+, s^+))构成正样本对, 自然我们需要学生网络提取的特征(s^+)能够逼近(t^+), 进一步, 构建负样本对, 采样样本({x_1^-, x_2^-, ldots, x_k^- }), 同时得到负样本对((t^+,s_i^-)), 其中(s_i^-=f^s(x_i^-)). 总的样本对就是

    [mathcal{S}_{pair} := {(t^+, s^+), (t^+, s_1^-), ldots, (t^+, s_k^-)}. ]

    根据负样本采样的损失, 最大化

    [J( heta):= mathbb{E}_{(t,s)sim p(t,s)} log P(1|t,s; heta) + mathbb{E}_{(t,s)sim q(t,s)} log P(0|t,s; heta). ]

    当然对于本文的问题需要特殊化, 既然先验(P(C=1)=frac{1}{k+1}, P(C=0)=frac{k}{k+1}), 故

    [J( heta):= mathbb{E}_{(t,s)sim p(t,s)} log P(1|t,s; heta) + kcdot mathbb{E}_{(t,s)sim q(t,s)} log P(0|t,s; heta). ]

    (q(t,s))是一个区别于(p(t,s))的分布, 本文采用了(p(t)q(s)).

    作者进一步对前一项加了解释

    [egin{array}{ll} P(1|t,s; heta) &= frac{P(t,s)P(C=1)}{P(t,s)P(C=1) + P(t)P(s)P(C=0)} \ &le frac{P(t,s)}{kcdot P(t)P(s)}, \ end{array} ]

    [mathbb{E}_{(t,s)sim p(t,s)} log P(1|t,s; heta) + log kle I(t,s). ]

    (J( heta))的第二项是负的, 故

    [J( heta) le I(t,s), ]

    所以最大化(J( heta))能够一定程度上最大化(t,s)的互信息.

    reweight

    教师网络一般要求精度(干净数据集上的准确率)比较高, 但是通过对抗训练所生成的教师网络往往并不具有这一特点, 所以作者采取的做法是, 对特征(t)根据其置信度来加权(w), 最后损失为

    [mathcal{L}( heta) := mathbb{E}_{(t,s)sim p(t,s)} w_t log P(1|t,s; heta) + kcdot mathbb{E}_{(t,s)sim p(t)p(s)} w_t log P(0|t,s; heta), ]

    其中

    [w_t leftarrow p_{ypred=y}(f^t,t^+) in [0, 1]. ]

    (w_t)为教师网络判断(t^+)类别为(y)(真实类别)的概率.

    拟合概率(P(1|t,s; heta))

    在负采样中, 这类概率是直接用逻辑斯蒂回归做的, 本文采用

    [P(1|t,s; heta) = h(t,s) = frac{e^{t^Ts/ au}}{e^{t^Ts/ au}+frac{k}{M}}, ]

    其中(M)为数据集的样本个数.
    会不会

    [frac{e^{t^Ts/ au}}{e^{t^Ts/ au}+gamma cdot frac{k}{M^2}}, ]

    (gamma)也作为一个参数训练符合NCE呢?

    实验的细节

    文中有如此一段话

    we sample negatives from different classes rather than different instances, when picking up a positive sample from the same class.

    也就是说在实际实验中, (t^+,s^+)对应的类别是同一类的, (t^+, s^-)对应的类别不是同一类的.

    In our view, adversarial examples are like hard examples supporting the decision boundaries. Without hard examples, the distilled models would certainly make mistakes. Thus, we adopt a self-supervised way to generate adversarial examples using Projected Gradient Descent (PGD).

    也就是说, (t, s)都是对抗样本?

    超参数: (k=16384), ( au=0.1).

    疑问

    算法中的采样都是针对单个样本的, 但是我想实际训练的时候应该还是batch的, 不然太慢了, 但是如果是batch的话, 怎么采样呢?

  • 相关阅读:
    重启宝塔面板后提示-ModuleNotFoundError: No module named 'geventwebsocket'
    浅谈自动化
    【测试基础】App测试要点总结
    记录python上传文件的坑(2)
    使用navicat连接只开放内网ip连接的数据库
    【测试基础】数据库索引
    记录python上传文件的坑(1)
    使用docker-compose安装wordpress
    2-2 远程管理命令-网卡和IP地址的概念
    2-1. 远程管理常用命令-关机和启动
  • 原文地址:https://www.cnblogs.com/MTandHJ/p/13772062.html
Copyright © 2011-2022 走看看