zoukankan      html  css  js  c++  java
  • How to get gradients with respect to the inputs in pytorch

    This is one way to find adversarial examples of CNN.

    The boilerplate:

    import torch
    from torch.autograd import Variable
    import torch.nn as nn
    import torch.optim as optim
    import numpy as np
    

      Define a simple network:

    class lolnet(nn.Module):
        def __init__(self):
            super(lolnet,self).__init__()
            self.a=nn.Linear(in_features=1,out_features=1,bias=False)
            self.a.weight = nn.Parameter(torch.FloatTensor([[0.6]]))
            self.b=nn.Linear(in_features=1,out_features=1,bias=False)
            self.b.weight=nn.Parameter(torch.FloatTensor([[0.6]]))
            
        def forward(self, inputs):
            return self.b(
                self.a(inputs)
            )
    

      The inputs

    inputs=np.array([[5]])
    inputs=torch.from_numpy(inputs).float()
    inputs=Variable(inputs)
    inputs.requires_grad=True
    net=lolnet()
    

      The optimizer

    opx=optim.SGD(
        params=[
            {"params":inputs}
        ],lr=0.5
    )
    

      The optimization process

    for i in range(50):
        x=net(inputs)
        loss=(x-1)**2
        opx.zero_grad() 
        loss.backward()
        opx.step()
        print(net.a.weight.data.numpy()[0][0],inputs.data.numpy()[0][0],loss.data.numpy()[0][0])
    

      The results are as below:

    0.6 4.712 0.6400001
    0.6 4.4613247 0.4848616
    0.6 4.243137 0.36732942
    0.6 4.0532265 0.27828723
    0.6 3.8879282 0.2108294
    0.6 3.7440526 0.15972354
    0.6 3.6188233 0.1210059
    0.6 3.5098238 0.09167358
    0.6 3.4149506 0.069451585
    0.6 3.332373 0.052616227
    0.6 3.2604973 0.039861854
    0.6 3.1979368 0.030199187
    0.6 3.143484 0.022878764
    0.6 3.0960886 0.017332876
    0.6 3.0548356 0.013131317
    0.6 3.0189288 0.00994824
    0.6 2.9876754 0.0075367615
    0.6 2.9604726 0.005709796
    0.6 2.9367952 0.0043257284
    0.6 2.9161866 0.003277142
    0.6 2.8982487 0.0024827516
    0.6 2.8826356 0.0018809267
    0.6 2.869046 0.001424982
    0.6 2.8572176 0.0010795629
    0.6 2.8469222 0.0008178701
    0.6 2.837961 0.00061961624
    0.6 2.830161 0.00046941772
    0.6 2.8233721 0.000355627
    0.6 2.8174632 0.0002694209
    0.6 2.81232 0.00020411481
    0.6 2.8078432 0.0001546371
    0.6 2.8039467 0.00011715048
    0.6 2.8005552 8.875507e-05
    0.6 2.7976031 6.724081e-05
    0.6 2.7950337 5.093933e-05
    0.6 2.7927973 3.8591857e-05
    0.6 2.7908509 2.9236677e-05
    0.6 2.7891567 2.2150038e-05
    0.6 2.7876818 1.6781378e-05
    0.6 2.7863982 1.2713146e-05
    0.6 2.785281 9.631679e-06
    0.6 2.7843084 7.296927e-06
    0.6 2.783462 5.527976e-06
    0.6 2.7827253 4.1880226e-06
    0.6 2.782084 3.1727632e-06
    0.6 2.7815259 2.4034823e-06
    0.6 2.78104 1.821013e-06
    0.6 2.7806172 1.3793326e-06
    0.6 2.780249 1.044933e-06
    0.6 2.7799287 7.9170513e-07
    
    Process finished with exit code 0
    

      

  • 相关阅读:
    local 不能解析为127.0.0.1
    完全使用接口方式调用WCF 服务
    【人生】自己对于求职应聘的一些感受
    OO的经典例子
    剪刀、石头、布机器人比赛
    TextTree 文本资料收集轻量级工具
    两个代替重复输入的小工具
    桌面助手 Desktop Helper 自动帮你关闭指定的窗口
    磁盘可用空间平衡
    用C#制造可以继承的“枚举”
  • 原文地址:https://www.cnblogs.com/cxxszz/p/8974640.html
Copyright © 2011-2022 走看看