  • How to get gradients with respect to the inputs in PyTorch

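    Computing the gradient of a loss with respect to a network's input, and then updating the input while the weights stay fixed, is one way to find adversarial examples for a CNN.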

    The boilerplate:

    import torch
    from torch.autograd import Variable  # legacy: deprecated since PyTorch 0.4, where tensors track gradients themselves
    import torch.nn as nn
    import torch.optim as optim
    import numpy as np
    

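      Define a simple network: two bias-free linear layers whose weights are both initialized to 0.6, so the whole network computes 0.36 * x: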

    class lolnet(nn.Module):
        def __init__(self):
            super(lolnet, self).__init__()
            # Two 1 -> 1 linear layers without bias, each with weight 0.6
            self.a = nn.Linear(in_features=1, out_features=1, bias=False)
            self.a.weight = nn.Parameter(torch.FloatTensor([[0.6]]))
            self.b = nn.Linear(in_features=1, out_features=1, bias=False)
            self.b.weight = nn.Parameter(torch.FloatTensor([[0.6]]))

        def forward(self, inputs):
            # b(a(x)) = 0.6 * 0.6 * x = 0.36 * x
            return self.b(self.a(inputs))
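Because both layers are bias-free with weight 0.6, the network reduces to a scalar multiplication, and the chain rule gives the gradients directly. Writing $w = w_b w_a$ for the combined weight and $L$ for the squared-error loss used in the loop further down, a short worked derivation:

```latex
f(x) = w_b\,(w_a\,x) = 0.36\,x, \qquad \frac{\partial f}{\partial x} = w_b\,w_a = 0.36 \\
L(x) = \bigl(f(x) - 1\bigr)^2, \qquad \frac{\partial L}{\partial x} = 2\,w\,(w x - 1) \\
x = 5:\quad f(5) = 1.8,\quad L(5) = 0.64,\quad \frac{\partial L}{\partial x}\Big|_{x=5} = 2 \cdot 0.36 \cdot 0.8 = 0.576
```

These are the values autograd should reproduce for the starting input.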
    

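      The input: a 1×1 tensor holding 5.0, marked as requiring gradients so that autograd computes d(loss)/d(input):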

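    inputs = np.array([[5]])
    inputs = torch.from_numpy(inputs).float()  # 1x1 float32 tensor
    # Since PyTorch 0.4 a plain tensor can require gradients; wrapping it in Variable is no longer needed
    inputs.requires_grad = True                 # autograd will fill inputs.grad on backward()
    net = lolnet()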
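To read the gradient with respect to the input without an optimizer at all, `torch.autograd.grad` can be called directly. This is a minimal sketch assuming a recent PyTorch; so that it runs on its own, it rebuilds the same two-layer network as an `nn.Sequential`, a hypothetical stand-in for `lolnet` rather than the author's code.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for lolnet: two bias-free 1 -> 1 linear layers, weights 0.6.
net = nn.Sequential(
    nn.Linear(1, 1, bias=False),
    nn.Linear(1, 1, bias=False),
)
with torch.no_grad():
    for layer in net:
        layer.weight.fill_(0.6)

x = torch.tensor([[5.0]], requires_grad=True)
loss = ((net(x) - 1) ** 2).sum()  # same squared error as the loop, reduced to a scalar

# Returns a tuple with one gradient per requested input; unlike loss.backward(),
# it does not write anything into x.grad or the weights' .grad.
(grad_x,) = torch.autograd.grad(loss, x)
print(grad_x)  # d(loss)/dx = 2 * 0.36 * (0.36 * 5 - 1), about 0.576
```

This is handy when only the gradient is wanted (for example to inspect it or take a single signed step), since nothing has to be zeroed afterwards.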
    

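      The optimizer: plain SGD with learning rate 0.5 over the input tensor only, so the network weights are never updated: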

    # Only the input is registered with the optimizer; net.parameters() are left out
    opx = optim.SGD(
        params=[
            {"params": inputs}
        ], lr=0.5
    )
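What `opx.step()` does to the input can be spelled out by hand. A minimal sketch, assuming the network is collapsed to its combined weight 0.36 and that SGD runs with only `lr` set (no momentum, no weight decay), in which case one step is `p <- p - lr * p.grad`:

```python
import torch

# Assumptions: combined weight 0.36 stands in for the two 0.6 layers, and SGD
# has no momentum or weight decay, so one step is p <- p - lr * p.grad.
w = 0.36
lr = 0.5
x = torch.tensor([[5.0]], requires_grad=True)

loss = ((w * x - 1) ** 2).sum()
loss.backward()           # x.grad = 2 * w * (w * x - 1) = 0.576

with torch.no_grad():     # update the leaf in place without recording it in the graph
    x -= lr * x.grad
x.grad.zero_()            # clear the gradient before the next backward (the role of opx.zero_grad())

print(x.item())           # about 4.712, the input shown in the first result row
```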
    

      The optimization loop, which pushes the network output toward 1 by changing only the input:

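    for i in range(50):
        x = net(inputs)        # forward pass: 0.36 * input
        loss = (x - 1) ** 2    # squared error against the target output 1
        opx.zero_grad()        # clear the input's gradient from the previous step
        loss.backward()        # compute d(loss)/d(inputs); the weights' .grad also accumulates but is never used
        opx.step()             # inputs <- inputs - 0.5 * inputs.grad
        # weight of layer a (unchanged), input after the step, loss before the step
        print(net.a.weight.data.numpy()[0][0], inputs.data.numpy()[0][0], loss.data.numpy()[0][0])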
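The same loop can be replicated in plain Python with the gradient worked out by hand, which is a cheap way to check that autograd computes what we expect. This sketch assumes the network reduces to `f(x) = 0.36 * x`; the names `W`, `LR` and `run` are illustrative, not from the original.

```python
# Plain-Python replica of the loop, using the hand-derived gradient instead of autograd.
# Assumption: the two 0.6 layers collapse to a single weight W = 0.36.
W = 0.6 * 0.6   # combined weight of layers a and b
LR = 0.5        # same learning rate as opx


def run(x0=5.0, steps=50):
    """Return a list of (input_after_step, loss_before_step) pairs, one per iteration."""
    x = x0
    history = []
    for _ in range(steps):
        out = W * x
        loss = (out - 1) ** 2
        grad = 2 * (out - 1) * W   # d(loss)/dx by the chain rule
        x -= LR * grad             # plain SGD step on the input
        history.append((x, loss))
    return history


history = run()
print(history[0])    # first row: input about 4.712, loss about 0.64
print(history[-1])   # last row: input about 2.77993, loss about 7.92e-07
```

Small differences in the last digits against the PyTorch output are expected, since this runs in float64 while the tensors above are float32.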
    

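      The results are below. The columns are the weight of layer a, the input after the step, and the loss before the step. The weight stays at 0.6 while the input converges toward 1 / 0.36 ≈ 2.778, where the output is exactly 1: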

    0.6 4.712 0.6400001
    0.6 4.4613247 0.4848616
    0.6 4.243137 0.36732942
    0.6 4.0532265 0.27828723
    0.6 3.8879282 0.2108294
    0.6 3.7440526 0.15972354
    0.6 3.6188233 0.1210059
    0.6 3.5098238 0.09167358
    0.6 3.4149506 0.069451585
    0.6 3.332373 0.052616227
    0.6 3.2604973 0.039861854
    0.6 3.1979368 0.030199187
    0.6 3.143484 0.022878764
    0.6 3.0960886 0.017332876
    0.6 3.0548356 0.013131317
    0.6 3.0189288 0.00994824
    0.6 2.9876754 0.0075367615
    0.6 2.9604726 0.005709796
    0.6 2.9367952 0.0043257284
    0.6 2.9161866 0.003277142
    0.6 2.8982487 0.0024827516
    0.6 2.8826356 0.0018809267
    0.6 2.869046 0.001424982
    0.6 2.8572176 0.0010795629
    0.6 2.8469222 0.0008178701
    0.6 2.837961 0.00061961624
    0.6 2.830161 0.00046941772
    0.6 2.8233721 0.000355627
    0.6 2.8174632 0.0002694209
    0.6 2.81232 0.00020411481
    0.6 2.8078432 0.0001546371
    0.6 2.8039467 0.00011715048
    0.6 2.8005552 8.875507e-05
    0.6 2.7976031 6.724081e-05
    0.6 2.7950337 5.093933e-05
    0.6 2.7927973 3.8591857e-05
    0.6 2.7908509 2.9236677e-05
    0.6 2.7891567 2.2150038e-05
    0.6 2.7876818 1.6781378e-05
    0.6 2.7863982 1.2713146e-05
    0.6 2.785281 9.631679e-06
    0.6 2.7843084 7.296927e-06
    0.6 2.783462 5.527976e-06
    0.6 2.7827253 4.1880226e-06
    0.6 2.782084 3.1727632e-06
    0.6 2.7815259 2.4034823e-06
    0.6 2.78104 1.821013e-06
    0.6 2.7806172 1.3793326e-06
    0.6 2.780249 1.044933e-06
    0.6 2.7799287 7.9170513e-07
    
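The numbers can be checked by hand. With combined weight $w = 0.36$ and learning rate $\eta = 0.5$, the SGD update on the input is a linear recurrence whose fixed point is $1/w$:

```latex
x_{t+1} = x_t - \eta \cdot 2w\,(w x_t - 1) \\
x_{t+1} - \tfrac{1}{w} = \bigl(1 - 2\eta w^2\bigr)\left(x_t - \tfrac{1}{w}\right) = 0.8704\left(x_t - \tfrac{1}{w}\right) \\
x_t = \tfrac{1}{w} + \left(x_0 - \tfrac{1}{w}\right) 0.8704^{\,t}, \qquad \tfrac{1}{w} \approx 2.7778
```

So the distance to 2.7778 shrinks by a factor 0.8704 per step and the loss by 0.8704² ≈ 0.7576 (0.64 × 0.7576 ≈ 0.4849, matching the second row). After 50 steps, x ≈ 2.7778 + 2.2222 × 0.8704^50 ≈ 2.7799, matching the last row.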
    

      

  • Original article: https://www.cnblogs.com/cxxszz/p/8974640.html