  • 3. Saving a neural network, and 2 ways to restore it

     1 """
     2 torch: 0.4
     3 matplotlib
     4 神经网络的保存 
     5 神经网络提取的2 ways
     6 """
     7 import torch
     8 import matplotlib.pyplot as plt
     9 
    10 
    11 
    12 # fake data
    13 x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
    14 y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)
    15 
    16 # The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
    17 # x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)
    18 
    19 # save net1
    20 def save():
    21     # 建立网络实例net1
    22     net1 = torch.nn.Sequential(
    23         torch.nn.Linear(1, 10),
    24         torch.nn.ReLU(),
    25         torch.nn.Linear(10, 1)
    26     )
    27     optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    28     loss_func = torch.nn.MSELoss()
    29     #训练
    30     for t in range(100):
    31         prediction = net1(x)
    32         loss = loss_func(prediction, y)
    33         optimizer.zero_grad()
    34         loss.backward()
    35         optimizer.step()
    36 
    37     # plot result
    38     plt.figure(1, figsize=(10, 3))
    39     plt.subplot(131)
    40     plt.title('Net1')
    41     plt.scatter(x.data.numpy(), y.data.numpy())
    42     plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    43 
    44     # 2 ways to save the net
    45     torch.save(net1, 'net.pkl')  # save entire net
    46     torch.save(net1.state_dict(), 'net_params.pkl')   # save only the parameters
    47 
    48 
    49 def restore_net():
    50     # restore entire net1 to net2
    51     net2 = torch.load('net.pkl')
    52     prediction = net2(x)
    53 
    54     # plot result
    55     plt.subplot(132)
    56     plt.title('Net2')
    57     plt.scatter(x.data.numpy(), y.data.numpy())
    58     plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    59 
    60 
    61 def restore_params():
    62     # restore only the parameters in net1 to net3
    63     net3 = torch.nn.Sequential(
    64         torch.nn.Linear(1, 10),
    65         torch.nn.ReLU(),
    66         torch.nn.Linear(10, 1)
    67     )
    68 
    69     # copy net1's parameters into net3
    70     net3.load_state_dict(torch.load('net_params.pkl'))
    71     prediction = net3(x)
    72 
    73     # plot result
    74     plt.subplot(133)
    75     plt.title('Net3')
    76     plt.scatter(x.data.numpy(), y.data.numpy())
    77     plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    78     plt.show()
    79 
    80 # save net1
    81 save()
    82 
    83 # restore entire net (may slow)
    84 restore_net()
    85 
    86 # restore only the net parameters
    87 restore_params()
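
A side note on the difference between the two approaches: torch.save(net1, ...) pickles the whole object, so loading net.pkl back requires the original network definition to be importable, while saving the state_dict only stores the parameter tensors and is the more portable option. Below is a minimal sketch of reloading from the net_params.pkl file produced above; the map_location='cpu' argument is an optional addition that is not part of the original script, useful when the file was saved on a GPU machine and loaded on a CPU-only one.

    import torch

    # rebuild the same architecture as net1/net3 above
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )

    # load the saved parameters; map_location keeps this working on CPU-only machines
    state = torch.load('net_params.pkl', map_location='cpu')
    net.load_state_dict(state)
    net.eval()  # switch to evaluation mode before running inference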
  • Original article: https://www.cnblogs.com/xuechengmeigui/p/12388514.html