  • 3 Gradient-based attacks: MIM

    Original MIM paper: https://arxiv.org/pdf/1710.06081.pdf

    1. How the MIM attack works

    MIM stands for Momentum Iterative Method. Like PGD, it is a gradient-based iterative attack. Its essence is that at each iteration the perturbation depends not only on the current gradient direction but also on the gradients computed in earlier rounds. The decay factor tunes how much that history matters: decay_factor lies in [0, 1] (the code below defaults to 1.0), and the smaller it is, the less the gradients from early iterations influence the current update direction. Because earlier gradients keep influencing later iterations, the update direction does not drift, and the overall direction stays correct; the small illustration below makes the decay concrete.
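    A tiny illustration of mine (not from the original post): after t further momentum updates of the form g = decay_factor * g + grad, the gradient computed at iteration 0 survives inside g with weight decay_factor ** t.

    for decay_factor in (0.2, 0.5, 1.0):
        # weight of the iteration-0 gradient after t more momentum updates
        print(decay_factor, [round(decay_factor ** t, 4) for t in range(5)])
    # 0.2 [1.0, 0.2, 0.04, 0.008, 0.0016]   -- early gradients fade quickly
    # 0.5 [1.0, 0.5, 0.25, 0.125, 0.0625]
    # 1.0 [1.0, 1.0, 1.0, 1.0, 1.0]         -- full memory (the paper uses mu = 1)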

    To accelerate gradient descent, momentum accumulates a velocity vector along the gradient direction of the loss function across iterations, which (1) stabilizes the updates and (2) helps the optimization move through narrow valleys, small humps and poor local minima or maxima; roughly speaking, it helps avoid getting stuck in poor local optima.

    The per-iteration update in the paper is

        g_{t+1} = \mu \cdot g_t + \frac{\nabla_x J(x_t^{adv}, y)}{\lVert \nabla_x J(x_t^{adv}, y) \rVert_1}

        x_{t+1}^{adv} = x_t^{adv} + \alpha \cdot \mathrm{sign}(g_{t+1})

    where \mu is the decay_factor and \alpha is the step size (eps_iter in the code below). Note that in the original paper the gradient with respect to x at each iteration is normalized by its 1-norm, i.e. the .sum() of absolute values, whereas the algorithm libraries and the paper's reference implementation divide by the .mean() instead; this presumably has little effect on the result.
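    In fact the sum-versus-mean choice cannot change the attack at all: the mean is just the sum divided by the number of elements, so every accumulated term, and hence g itself, is rescaled by the same positive constant, leaving torch.sign(g) untouched. A quick check (my own sketch, not from the post):

    import torch

    torch.manual_seed(0)
    grads = [torch.randn(3, 4) for _ in range(5)]  # stand-in per-step gradients
    mu = 1.0                                       # decay_factor
    g_sum = torch.zeros(3, 4)
    g_mean = torch.zeros(3, 4)
    for grad in grads:
        g_sum = mu * g_sum + grad / grad.abs().sum()     # paper: 1-norm (.sum())
        g_mean = mu * g_mean + grad / grad.abs().mean()  # libraries: .mean()
    # mean = sum / numel, so g_mean is just numel * g_sum
    print(torch.allclose(g_mean, g_sum * grads[0].numel()))   # True
    print(torch.equal(torch.sign(g_sum), torch.sign(g_mean))) # True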

    2. Code implementation

    import torch
    import torch.nn as nn

    # Imports needed to run this listing; Attack, LabelMixin, clamp and
    # normalize_by_pnorm all come from the advertorch package.
    from advertorch.attacks.base import Attack, LabelMixin
    from advertorch.utils import clamp, normalize_by_pnorm


    class MomentumIterativeAttack(Attack, LabelMixin):
        """
        The L-inf momentum iterative attack (Dong et al. 2017).
        The attack performs nb_iter steps of size eps_iter, while always staying
        within eps from the initial point. The optimization is performed with
        momentum.
        Paper: https://arxiv.org/pdf/1710.06081.pdf
        """
    
        def __init__(
                self, predict, loss_fn=None, eps=0.3, nb_iter=40, decay_factor=1.,
                eps_iter=0.01, clip_min=0., clip_max=1., targeted=False):
            """
            Create an instance of the MomentumIterativeAttack.
    
            :param predict: forward pass function.
            :param loss_fn: loss function.
            :param eps: maximum distortion.
            :param nb_iter: number of iterations
            :param decay_factor: momentum decay factor.
            :param eps_iter: attack step size.
            :param clip_min: minimum value per input dimension.
            :param clip_max: maximum value per input dimension.
            :param targeted: if the attack is targeted.
            """
            super(MomentumIterativeAttack, self).__init__(
                predict, loss_fn, clip_min, clip_max)
            self.eps = eps
            self.nb_iter = nb_iter
            self.decay_factor = decay_factor
            self.eps_iter = eps_iter
            self.targeted = targeted
            if self.loss_fn is None:
                self.loss_fn = nn.CrossEntropyLoss(reduction="sum")
    
        def perturb(self, x, y=None):
            """
            Given examples (x, y), returns their adversarial counterparts with
            an attack length of eps.
    
            :param x: input tensor.
            :param y: label tensor.
                      - if None and self.targeted=False, compute y as predicted
                        labels.
                      - if self.targeted=True, then y must be the targeted labels.
            :return: tensor containing perturbed inputs.
            """
            x, y = self._verify_and_process_inputs(x, y)
    
            delta = torch.zeros_like(x)
            g = torch.zeros_like(x)
    
            delta = nn.Parameter(delta)
    
            for i in range(self.nb_iter):
    
                if delta.grad is not None:
                    delta.grad.detach_()
                    delta.grad.zero_()
    
                imgadv = x + delta
                outputs = self.predict(imgadv)
                loss = self.loss_fn(outputs, y)
                if self.targeted:
                    loss = -loss
                loss.backward()
    
                # Momentum accumulation: decay the running gradient g and add
                # the current gradient, normalized by its L1 norm.
                g = self.decay_factor * g + normalize_by_pnorm(
                    delta.grad.data, p=1)
                # according to the paper it should be .sum(), but in their
                #   implementations (both cleverhans and the link from the paper)
                #   it is .mean(), but actually it shouldn't matter
    
                # Step of size eps_iter in the sign direction of the
                # accumulated gradient g.
                delta.data += self.eps_iter * torch.sign(g)
                # delta.data += self.eps / self.nb_iter * torch.sign(g)

                # Project the perturbation back into the L-inf ball of radius
                # eps, then keep the perturbed input inside [clip_min, clip_max].
                delta.data = clamp(
                    delta.data, min=-self.eps, max=self.eps)
                delta.data = clamp(
                    x + delta.data, min=self.clip_min, max=self.clip_max) - x
    
            rval = x + delta.data
            return rval
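    For reference, a minimal usage sketch (the toy model and data here are placeholders of mine; only the attack class and its signature come from the listing above):

    import torch
    import torch.nn as nn
    from advertorch.attacks import MomentumIterativeAttack

    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))  # stand-in net
    model.eval()

    x = torch.rand(8, 1, 28, 28)    # batch of inputs scaled to [0, 1]
    y = torch.randint(0, 10, (8,))  # ground-truth labels

    adversary = MomentumIterativeAttack(
        model, eps=0.3, nb_iter=40, eps_iter=0.01, decay_factor=1.0,
        clip_min=0.0, clip_max=1.0, targeted=False)
    x_adv = adversary.perturb(x, y)  # same shape as x, within eps in L-inf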
    

      

    Some readers have argued that in advertorch's iteration the derivative should be taken with respect to imgadv rather than with respect to delta; the foolbox and cleverhans implementations both differentiate with respect to each round's adversarial example.
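    Mathematically this makes no difference, though: since imgadv = x + delta and x is held fixed, the gradient of the loss with respect to delta equals the gradient with respect to imgadv by the chain rule. A minimal check (my sketch, using a toy linear model in place of the real network):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    model = nn.Linear(4, 3)  # toy stand-in for self.predict
    loss_fn = nn.CrossEntropyLoss(reduction="sum")
    x = torch.rand(2, 4)
    y = torch.tensor([0, 2])

    # advertorch style: differentiate with respect to delta
    delta = torch.zeros_like(x, requires_grad=True)
    loss_fn(model(x + delta), y).backward()

    # foolbox/cleverhans style: differentiate with respect to imgadv itself
    imgadv = x.clone().detach().requires_grad_(True)
    loss_fn(model(imgadv), y).backward()

    print(torch.allclose(delta.grad, imgadv.grad))  # True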
