  • Adversarial Examples Improve Image Recognition

    Xie C., Tan M., Gong B., et al. Adversarial Examples Improve Image Recognition. arXiv preprint arXiv:1911.09665, 2019.

    @article{xie2019adversarial,
      title={Adversarial Examples Improve Image Recognition},
      author={Xie, Cihang and Tan, Mingxing and Gong, Boqing and Wang, Jiang and Yuille, Alan L and Le, Quoc V},
      journal={arXiv preprint arXiv:1911.09665},
      year={2019}
    }

    To make the network more robust, the authors focus on the adversarial training objective

    \[ \mathop{\arg\min}_{\theta} \; \mathbb{E}_{(x, y)\sim \mathbb{D}} \Big[ L(\theta, x, y) + \max_{\epsilon \in \mathbb{S}} L(\theta, x+\epsilon, y) \Big], \]

    This is essentially adversarial training. However, simply training on a mixture of ordinary samples and their corresponding adversarial samples does not work well. The authors attribute this to batch normalization: it suffices to give ordinary samples and adversarial samples separate BatchNorm modules during training.
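    The inner maximization is approximated with an iterative attack; with PGD (the attack used in the paper), each step takes a signed-gradient ascent step of size \(\alpha\) and projects the iterate back into the allowed perturbation set \(\mathbb{S}\):

    \[ x^{t+1} = \Pi_{\mathbb{S}}\big(x^{t} + \alpha \cdot \mathrm{sign}(\nabla_{x} L(\theta, x^{t}, y))\big). \]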

    Main idea

    The authors argue that ordinary training samples and adversarial samples follow different distributions, so normalizing both with a single BatchNorm works poorly. They therefore add an auxiliary BatchNorm during training, used exclusively for adversarial training samples; outside of training, only the ordinary BatchNorm is used.

    Each training step proceeds as follows (a minimal sketch follows the list):

    1. Sample a batch of ordinary training examples \(x^c\) with labels \(y\);
    2. Generate adversarial examples \(x^a\) with an attack algorithm (PGD in the paper), using the auxiliary BatchNorm;
    3. Compute the loss \(L^c(\theta, x^c, y)\) using the ordinary BatchNorm;
    4. Compute the loss \(L^a(\theta, x^a, y)\) using the auxiliary BatchNorm;
    5. Backward pass on \(L^c(\theta, x^c, y) + L^a(\theta, x^a, y)\), and update \(\theta\).
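    A minimal sketch of one such training step, assuming the network's forward pass takes an `adv` flag that switches between the two BatchNorms (as the Mixturenorm modules at the end of this post do); `pgd_attack` is a hypothetical stand-in for the attack in step 2:

    def advprop_step(net, opti, criterion, x_clean, y, pgd_attack):
        # step 2: craft adversarial examples; the attack is expected to
        # run its forward passes with adv=True (auxiliary BatchNorm)
        x_adv = pgd_attack(net, x_clean, y)
        opti.zero_grad()
        # step 3: loss on clean samples through the ordinary BatchNorm
        loss_c = criterion(net(x_clean, adv=False), y)
        # step 4: loss on adversarial samples through the auxiliary BatchNorm
        loss_a = criterion(net(x_adv, adv=True), y)
        # step 5: joint backward pass, then update theta
        (loss_c + loss_a).backward()
        opti.step()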

    Overview of experiments

    Datasets: ImageNet-A, ImageNet-C, Stylized-ImageNet.

    Section 5.2: AdvProp reaches 85.2% top-1 accuracy on ImageNet (Fig. 4);
    Evaluation on corrupted versions of ImageNet (Table 4; mCE, mean corruption error, lower is better);
    Effect of adversarial-attack strength on classification accuracy: when a network's "capacity" is low, a weak attack actually works better; when its "capacity" is high, a stronger attack works better (Table 2);
    Comparison of AdvProp with ordinary adversarial training (Fig. 5);
    The higher a network's "capacity", the smaller the gain from AdvProp;
    Comparison of AutoAugment with AdvProp (Table 6);
    Effect of different adversarial attacks (Table 7).

    Code

    The code below is untested.

    
    
    """
    white-box attacks:
    iFGSM
    PGD
    """
    
    
    import torch
    import torch.nn as nn
    
    class WhiteBox:
    
        def __init__(self, net, epsilon: float, times: int, criterion=None):
            self.net = net
            self.epsilon = epsilon  # step size of each attack iteration
            self.times = times      # maximum number of iterations
            if criterion is None:
                self.criterion = nn.CrossEntropyLoss()
            else:
                self.criterion = criterion
    
        @staticmethod
        def calc_jacobian(loss, inp):
            jacobian = torch.autograd.grad(loss, inp, retain_graph=True)[0]
            return jacobian
    
        @staticmethod
        def sgn(matrix):
            return torch.sign(matrix)
    
        @staticmethod
        def pre(out):
            return torch.argmax(out, dim=1)
    
        def fgsm(self, inp, y):
            inp.requires_grad_(True)
            out = self.net(inp)
            loss = self.criterion(out, y)
            # ascent direction: sign of the loss gradient w.r.t. the input
            delta = self.sgn(self.calc_jacobian(loss, inp))
            flag = False
            inp_new = inp.data
            # take up to `times` steps of size epsilon, stopping early
            # once the prediction flips
            for i in range(self.times):
                inp_new = inp_new + self.epsilon * delta
                out_new = self.net(inp_new)
                if self.pre(out_new) != y:
                    flag = True
                    break
            return flag, inp_new
    
        def ifgsm(self, inps, ys):
            N = len(inps)
            adversarial_samples = []
            for i in range(N):
                flag, inp_new = self.fgsm(
                    inps[[i]], ys[[i]]
                )
                if flag:
                    adversarial_samples.append(inp_new)
    
            if not adversarial_samples:  # no sample was successfully attacked
                return torch.empty(0), 0.
            return (torch.cat(adversarial_samples),
                    len(adversarial_samples) / N)
    
        def pgd(self, inp, y, perturb):
            # feasible region: an l_inf ball of radius `perturb` around inp
            boundary_low = inp - perturb
            boundary_up = inp + perturb
            flag = False
            inp_new = inp.data
            for i in range(self.times):
                # recompute the gradient at the current iterate
                inp_new = inp_new.detach().requires_grad_(True)
                out = self.net(inp_new)
                loss = self.criterion(out, y)
                delta = self.sgn(self.calc_jacobian(loss, inp_new))
                # signed-gradient step, then projection back into the ball
                inp_new = torch.clamp(
                    inp_new.data + self.epsilon * delta,
                    boundary_low,
                    boundary_up
                )
                out_new = self.net(inp_new)
                if self.pre(out_new) != y:
                    flag = True
                    break
            return flag, inp_new
    
        def ipgd(self, inps, ys, perturb):
            N = len(inps)
            adversarial_samples = []
            for i in range(N):
                flag, inp_new = self.pgd(
                    inps[[i]], ys[[i]],
                    perturb
                )
                if flag:
                    adversarial_samples.append(inp_new)
    
            if not adversarial_samples:  # no sample was successfully attacked
                return torch.empty(0), 0.
            return (torch.cat(adversarial_samples),
                    len(adversarial_samples) / N)
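    # A hedged usage sketch of the WhiteBox attacker above; the toy
    # classifier, inputs, and hyper-parameters are hypothetical stand-ins.
    if __name__ == "__main__":
        net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
        x = torch.rand(4, 1, 28, 28)
        y = torch.randint(0, 10, (4,))
        attacker = WhiteBox(net, epsilon=0.03, times=10)
        advs, success_rate = attacker.ifgsm(x, y)        # iterated FGSM
        advs2, success_rate2 = attacker.ipgd(x, y, 0.1)  # PGD, radius 0.1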
    
    
    
    
    
    """
    black-box attack
    see Practical Black-Box Attacks against Machine Learning.
    """
    
    
    
    import torch
    import torch.nn as nn
    from torch.utils.data import Dataset, DataLoader
    
    
    class Synthetic(Dataset):
    
        def __init__(self, data, labels):
            self.data = data
            self.labels = labels
    
        def __len__(self):
            return len(self.data)
    
        def __getitem__(self, index):
            return self.data[index], self.labels[index]
    
    
    
    
    
    class Blackbox:
    
        def __init__(self, oracle, substitute, data, trainer, lamb):
            self.oracle = oracle
            self.substitute = substitute
            self.data = []
            self.trainer = trainer
            self.lamb = lamb
            self.update(data)
    
    
        def update(self, data):
            # query the oracle for labels of the new synthetic inputs;
            # the oracle is assumed to return hard labels, not logits
            labels = self.oracle(data)
            self.data.append(Synthetic(data, labels))

        def train(self):
            return self.trainer(self.substitute, self.data, self.lamb)
    
    def quireone(func):  # a decorator, to make defining the optimizer easy
        """
        Defer the first positional argument:
        quireone(f)(*args, **kwargs) returns g with
        g(arg) == f(*args, arg, **kwargs).
        It must be defined before Trainer, since the decorator is applied
        when the class body is executed.
        """
        def wrapper1(*args, **kwargs):
            def wrapper2(arg):
                return func(*args, arg, **kwargs)
            wrapper2.__doc__ = func.__doc__
            wrapper2.__name__ = func.__name__
            return wrapper2
        return wrapper1


    class Trainer:

        def __init__(self,
                     lr, weight_decay,
                     batch_size, shuffle=True, **kwargs):
            """
            :param lr: learning rate
            :param weight_decay: weight decay for SGD
            :param batch_size: batch_size for the dataloader
            :param shuffle: shuffle flag for the dataloader
            :param kwargs: other configs for the dataloader
            """
            self.kwargs = {"batch_size": batch_size,
                           "shuffle": shuffle}
            self.kwargs.update(kwargs)
            self.criterion = nn.CrossEntropyLoss()
            # self.opti(parameters) builds the optimizer once the
            # substitute's parameters are available
            self.opti = self.optim(lr=lr, weight_decay=weight_decay)
    
        @quireone
        def optim(self, parameters, **kwargs):
            """
            quireone is the decorator defined above; it lets the optimizer
            be configured before the parameters are known.
            :param parameters: net.parameters()
            :param kwargs: other configs
            :return: a configured torch.optim.SGD instance
            """
            return torch.optim.SGD(parameters, **kwargs)
    
        def dataloader(self, dataset):
            return DataLoader(dataset, **self.kwargs)
    
        @staticmethod
        def calc_jacobian(out, inp):
            jacobian = torch.autograd.grad(out, inp, retain_graph=True)[0]
            return jacobian
    
        @staticmethod
        def sgn(matrix):
            return torch.sign(matrix)
    
        def newdata(self, outs, inps, labels, lamb):
            # Jacobian-based dataset augmentation:
            # x' = x + lamb * sign(d out[label] / d x)
            data = inps.data.clone()  # clone so inps itself is not modified
            for i in range(len(labels)):
                out = outs[i, labels[i]]
                data += lamb * self.sgn(self.calc_jacobian(out, inps))

            return data
    
        def train(self, net, criterion, opti, dataloader, lamb=None,
                  update=False):
            """
            :param net: the substitute to be trained
            :param criterion: loss function
            :param opti: optimizer
            :param dataloader:
            :param lamb: lambda for augmenting S
            :param update: if True, train will return the new data
            :return: the augmented inputs when update is True, else None
            """
            if update:
                assert lamb is not None, "lamb needed when updating"
                newd = []
            for i, data in enumerate(dataloader):
                inps, labels = data
                inps.requires_grad_(True)
                outs = net(inps)
                loss = criterion(outs, labels)

                if update:
                    new_samples = self.newdata(outs, inps, labels, lamb)
                    newd.append(new_samples)

                opti.zero_grad()
                loss.backward()
                opti.step()
            if update:
                return torch.cat(newd)
    
    
        def __call__(self, substitute, data, lamb):
            N = len(data)
            opti = self.opti(substitute.parameters())
            for i, item in enumerate(data):
                dataloader = self.dataloader(item)
                if i == N - 1:
                    # newest dataset: also return the augmented inputs
                    return self.train(substitute, self.criterion,
                                      opti, dataloader, lamb, True)
                else:
                    self.train(substitute, self.criterion,
                               opti, dataloader)
    
    
    
    
    
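    # A hedged driver sketch for the substitute training above (Papernot et
    # al.'s Jacobian-based dataset augmentation); the oracle, the substitute,
    # and the seed set below are hypothetical stand-ins.
    if __name__ == "__main__":
        class Oracle(nn.Module):
            # toy oracle returning hard labels, as Blackbox.update assumes
            def __init__(self):
                super(Oracle, self).__init__()
                self.fc = nn.Linear(784, 10)

            def forward(self, x):
                return torch.argmax(self.fc(x), dim=1)

        substitute = nn.Linear(784, 10)     # hypothetical substitute net
        seed = torch.rand(20, 784)          # hypothetical seed set S_0
        trainer = Trainer(lr=0.01, weight_decay=1e-4, batch_size=5)
        attacker = Blackbox(Oracle(), substitute, seed, trainer, lamb=0.1)
        for rho in range(3):                # augmentation rounds
            new_inputs = attacker.train()   # train substitute, get new inputs
            attacker.update(new_inputs)     # oracle labels the new inputs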
    
    
    
    
    
    
    
    """
    Adversarial Examples Improve Image Recognition
    """
    
    
    import torch
    import torch.nn as nn
    
    
    class Mixturenorm1d(nn.Module):
    
        def __init__(self, rel, num_features: int, *args, **kwargs):
            super(Mixturenorm1d, self).__init__()
            # norm1: ordinary BatchNorm for clean samples;
            # norm2: auxiliary BatchNorm for adversarial samples
            self.norm1 = nn.BatchNorm1d(num_features,
                                        *args, **kwargs)
            self.norm2 = nn.BatchNorm1d(num_features,
                                        *args, **kwargs)
            self.rel = rel  # the owning network, which exposes .adv
    
        def forward(self, x):
            # during training, adversarial batches go through the auxiliary
            # BatchNorm; everything else uses the ordinary one
            if self.rel.adv and self.rel.training:
                return self.norm2(x)
            else:
                return self.norm1(x)

        def __setattr__(self, name, value):
            """
            __setattr__ has to be overridden, or self.rel would be
            registered as a child module of Mixturenorm1d; calling
            instance.modules() or instance.children() would then raise
            RecursionError: maximum recursion depth exceeded.
            """
            if name == "rel":
                object.__setattr__(self, name, value)
            else:
                super(Mixturenorm1d, self).__setattr__(name, value)
    
    
    
    class Mixturenorm2d(nn.Module):
    
        def __init__(self, rel, num_features: int, *args, **kwargs):
            super(Mixturenorm2d, self).__init__()
            self.norm1 = nn.BatchNorm2d(num_features,
                                        *args, **kwargs)
            self.norm2 = nn.BatchNorm2d(num_features,
                                        *args, **kwargs)
            self.rel = rel

        def forward(self, x):
            if self.rel.adv and self.rel.training:
                return self.norm2(x)
            else:
                return self.norm1(x)

        def __setattr__(self, name, value):
            # see Mixturenorm1d.__setattr__ for why this is needed
            if name == "rel":
                object.__setattr__(self, name, value)
            else:
                super(Mixturenorm2d, self).__setattr__(name, value)
    
    
    
    if __name__ == "__main__":
    
        class Testnet(nn.Module):
    
            def __init__(self):
                super(Testnet, self).__init__()
                self.adv = False  # read by Mixturenorm1d via self.rel.adv
                self.dense = nn.Sequential(
                    nn.Linear(10, 20),
                    Mixturenorm1d(self, 20),
                    nn.ReLU(),
                    nn.Linear(20, 1)
                )

            def forward(self, x, adv=False):
                self.adv = adv
                return self.dense(x)
    
    
        x = torch.rand((3, 10))
        test = Testnet()
        out = test(x, True)
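        # At test time only the ordinary BatchNorm should be active; a quick
        # check: eval() sets training=False, so Mixturenorm1d uses norm1
        test.eval()
        out_eval = test(x)  # adv also defaults to False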
    
    