zoukankan      html  css  js  c++  java
  • Universal adversarial perturbations

    Moosavidezfooli S, Fawzi A, Fawzi O, et al. Universal Adversarial Perturbations[C]. computer vision and pattern recognition, 2017: 86-94.

    @article{moosavidezfooli2017universal,
    title={Universal Adversarial Perturbations},
    author={Moosavidezfooli, Seyedmohsen and Fawzi, Alhussein and Fawzi, Omar and Frossard, Pascal},
    pages={86--94},
    year={2017}}

    深度学习的脆弱以及周知了, 但是到底脆弱到何种程度, 本文作者发现, 可以找到一种对抗摄动, 令其作用在不同的数据上, 结果大部分都会被攻破(即被网络误判). 甚至, 这种对抗摄动可以跨越网络结构的障碍, 即在这个网络结构上的一般性的对抗摄动, 也能有效地攻击别的网络.

    主要内容

    一般地对抗样本, 是针对特定的网络(hat{k}), 特定的样本(x_i in mathbb{R}^d), 期望找到摄动(v_i in mathbb{R}^d), 使得

    [v_i= arg min_r |r|_p : mathrm{s.t.} : hat{k}(x_i+r) ot = hat{k}(x_i) . ]

    而本文的通用摄动(universal perturbations)是希望找到一个(v, |v|_p le xi), 且

    [mathbb{P}_{x sim mu} (hat{k}(x+v) ot = hat{k}(x)) ge 1-delta, ]

    其中(mu)为数据的分布.

    算法

    构造这样的(v)的算法如下:
    在这里插入图片描述

    其中

    [mathcal{P}_{p, xi} (v)= arg min_{v'} |v'-v|_2, : mathrm{s.t.} |v'|_p le xi, ]

    为向(p)范数球的投影.

    实验部分

    实验1

    实验1, 在训练数据集(ILSVRC 2012)(X)上(摄动是在此数据上计算得到的), 以及在验证数据集上, 攻击不同的网络.
    在这里插入图片描述

    实验2

    实验2测试这种通用摄动的网络结构的转移性, 即在一个网络上寻找摄动, 在其它模型上计算此摄动的攻击成功率. 可见, 这种通用摄动的可迁移性是很强的.
    在这里插入图片描述

    实验3

    实验3, 研究了样本个数对攻击成功率的一个影响, 可以发现, 即便我们用来生成摄动的样本个数只有500(少于类别个数1000)都能有不错的成功率.
    在这里插入图片描述

    代码

    代码因为还有用到了别的模块, 这放在这里看看, 论文有自己的代码.

    
    import torch
    import logging
    from configs.adversarial.universal_attack_cfg import cfg
    
    
    sub_logger = logging.getLogger("__main__.__submodule__")
    
    class AttackUni:
    
        def __init__(self, net, device,
                     attack=cfg.attack, epsilon=cfg.epsilon, attack_cfg=cfg.attack_cfg,
                     max_iterations=cfg.max_iterations, fooling_rate=cfg.fooling_rate,
                     boxmin=0., boxmax=1.):
            """ the attack to construct universal perturbation
            :param net: the model
            :param device: may use gpu to train
            :param attack: default: PGDAttack
            :param epsilon: the epsilon to constraint the perturbation
            :param attack_cfg: the attack's config
            :param max_iterations: max_iterations for stopping early
            :param fooling_rate: the fooling rate we want
            :param boxmin: default: 0
            :param boxmax: default: 1
            """
    
            attack_cfg['net'] = net
            attack_cfg['device'] = device
            self.net = net
            self.device = device
            self.epsilon = epsilon
            self.attack = attack(**attack_cfg)
            self.max_iterations = max_iterations
            self.fooling_rate = fooling_rate
            self.boxmin = boxmin
            self.boxmax = boxmax
    
        def initialize_perturbation(self):
            self.perturbation = torch.tensor(0., device=self.device)
    
        def update_perturbation(self, perturbation):
            self.perturbation += perturbation
            self.perturbation = self.clip(self.perturbation).to(self.device)
    
        def clip(self, x):
            return torch.clamp(x, -self.epsilon, self.epsilon)
    
        def compare(self, x, label):
            x_adv = x + self.perturbation
            out = self.net(x_adv)
            pre = out.argmax(dim=1)
            return (pre == label).sum()
    
        def attack_one(self, img):
            result = self.attack.attack_batch(img+self.perturbation)
            perturbation = result['perturbations'][0]
            self.update_perturbation(perturbation)
    
        def attack_batch(self, dataloader):
            total = len(dataloader)
            self.initialize_perturbation()
            for epoch in range(self.max_iterations):
                count = 0
                for img, label in dataloader:
                    img = img.to(self.device)
                    label = img.to(self.device)
                    if self.compare(img, label):
                        self.attack_one(img)
                    else:
                        count += 1
                if count / total > self.fooling_rate:
                    break
                sub_logger.info("[epoch: {0:<3d}] 'fooling_rate': {1:<.6f}".format(
                    epoch, count / total
                ))
            return self.perturbation
    
    
    
  • 相关阅读:
    multipart/form-data
    Java面试之SE基础基本数据类型
    数据库中的悲观锁和乐观锁详解
    j2SE基回顾(一)
    Hibernate 检索查询的几种方式(HQL,QBC,本地SQL,集成Spring等)
    消防(bzoj 2282)
    YY的GCD(bzoj 2820)
    Problem b(bzoj 2301)
    完全平方数(bzoj 2440)
    The Luckiest number(hdu 2462)
  • 原文地址:https://www.cnblogs.com/MTandHJ/p/13040274.html
Copyright © 2011-2022 走看看