zoukankan      html  css  js  c++  java
  • How to get gradients with respect to the inputs in pytorch

    This is one way to find adversarial examples of CNN.

    The boilerplate:

    import torch
    from torch.autograd import Variable
    import torch.nn as nn
    import torch.optim as optim
    import numpy as np
    

      Define a simple network:

    class lolnet(nn.Module):
        def __init__(self):
            super(lolnet,self).__init__()
            self.a=nn.Linear(in_features=1,out_features=1,bias=False)
            self.a.weight = nn.Parameter(torch.FloatTensor([[0.6]]))
            self.b=nn.Linear(in_features=1,out_features=1,bias=False)
            self.b.weight=nn.Parameter(torch.FloatTensor([[0.6]]))
            
        def forward(self, inputs):
            return self.b(
                self.a(inputs)
            )
    

      The inputs

    inputs=np.array([[5]])
    inputs=torch.from_numpy(inputs).float()
    inputs=Variable(inputs)
    inputs.requires_grad=True
    net=lolnet()
    

      The optimizer

    opx=optim.SGD(
        params=[
            {"params":inputs}
        ],lr=0.5
    )
    

      The optimization process

    for i in range(50):
        x=net(inputs)
        loss=(x-1)**2
        opx.zero_grad() 
        loss.backward()
        opx.step()
        print(net.a.weight.data.numpy()[0][0],inputs.data.numpy()[0][0],loss.data.numpy()[0][0])
    

      The results are as below:

    0.6 4.712 0.6400001
    0.6 4.4613247 0.4848616
    0.6 4.243137 0.36732942
    0.6 4.0532265 0.27828723
    0.6 3.8879282 0.2108294
    0.6 3.7440526 0.15972354
    0.6 3.6188233 0.1210059
    0.6 3.5098238 0.09167358
    0.6 3.4149506 0.069451585
    0.6 3.332373 0.052616227
    0.6 3.2604973 0.039861854
    0.6 3.1979368 0.030199187
    0.6 3.143484 0.022878764
    0.6 3.0960886 0.017332876
    0.6 3.0548356 0.013131317
    0.6 3.0189288 0.00994824
    0.6 2.9876754 0.0075367615
    0.6 2.9604726 0.005709796
    0.6 2.9367952 0.0043257284
    0.6 2.9161866 0.003277142
    0.6 2.8982487 0.0024827516
    0.6 2.8826356 0.0018809267
    0.6 2.869046 0.001424982
    0.6 2.8572176 0.0010795629
    0.6 2.8469222 0.0008178701
    0.6 2.837961 0.00061961624
    0.6 2.830161 0.00046941772
    0.6 2.8233721 0.000355627
    0.6 2.8174632 0.0002694209
    0.6 2.81232 0.00020411481
    0.6 2.8078432 0.0001546371
    0.6 2.8039467 0.00011715048
    0.6 2.8005552 8.875507e-05
    0.6 2.7976031 6.724081e-05
    0.6 2.7950337 5.093933e-05
    0.6 2.7927973 3.8591857e-05
    0.6 2.7908509 2.9236677e-05
    0.6 2.7891567 2.2150038e-05
    0.6 2.7876818 1.6781378e-05
    0.6 2.7863982 1.2713146e-05
    0.6 2.785281 9.631679e-06
    0.6 2.7843084 7.296927e-06
    0.6 2.783462 5.527976e-06
    0.6 2.7827253 4.1880226e-06
    0.6 2.782084 3.1727632e-06
    0.6 2.7815259 2.4034823e-06
    0.6 2.78104 1.821013e-06
    0.6 2.7806172 1.3793326e-06
    0.6 2.780249 1.044933e-06
    0.6 2.7799287 7.9170513e-07
    
    Process finished with exit code 0
    

      

  • 相关阅读:
    5 输出的properties文件按照key进行排序
    JFinal 部署在 Tomcat 下推荐方法(转)
    15个必须知道的chrome开发者技巧(转)
    一次完整的浏览器请求流程(转)
    工作第一天
    Struts2的crud
    equal 和 ==
    生成Apk遇到的问题
    Http Framework
    Gradle: The New Android Build System
  • 原文地址:https://www.cnblogs.com/cxxszz/p/8974640.html
Copyright © 2011-2022 走看看