zoukankan      html  css  js  c++  java
  • 3. 神经网络的保存，以及提取神经网络的两种方式 (2 ways)

     1 """
     2 torch: 0.4
     3 matplotlib
     4 神经网络的保存 
     5 神经网络提取的2 ways
     6 """
     7 import torch
     8 import matplotlib.pyplot as plt
     9 
    10 
    11 
    12 # fake data
    13 x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
    14 y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)
    15 
    16 # The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
    17 # x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)
    18 
    19 # save net1
    20 def save():
    21     # 建立网络实例net1
    22     net1 = torch.nn.Sequential(
    23         torch.nn.Linear(1, 10),
    24         torch.nn.ReLU(),
    25         torch.nn.Linear(10, 1)
    26     )
    27     optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    28     loss_func = torch.nn.MSELoss()
    29     #训练
    30     for t in range(100):
    31         prediction = net1(x)
    32         loss = loss_func(prediction, y)
    33         optimizer.zero_grad()
    34         loss.backward()
    35         optimizer.step()
    36 
    37     # plot result
    38     plt.figure(1, figsize=(10, 3))
    39     plt.subplot(131)
    40     plt.title('Net1')
    41     plt.scatter(x.data.numpy(), y.data.numpy())
    42     plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    43 
    44     # 2 ways to save the net
    45     torch.save(net1, 'net.pkl')  # save entire net
    46     torch.save(net1.state_dict(), 'net_params.pkl')   # save only the parameters
    47 
    48 
    49 def restore_net():
    50     # restore entire net1 to net2
    51     net2 = torch.load('net.pkl')
    52     prediction = net2(x)
    53 
    54     # plot result
    55     plt.subplot(132)
    56     plt.title('Net2')
    57     plt.scatter(x.data.numpy(), y.data.numpy())
    58     plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    59 
    60 
    61 def restore_params():
    62     # restore only the parameters in net1 to net3
    63     net3 = torch.nn.Sequential(
    64         torch.nn.Linear(1, 10),
    65         torch.nn.ReLU(),
    66         torch.nn.Linear(10, 1)
    67     )
    68 
    69     # copy net1's parameters into net3
    70     net3.load_state_dict(torch.load('net_params.pkl'))
    71     prediction = net3(x)
    72 
    73     # plot result
    74     plt.subplot(133)
    75     plt.title('Net3')
    76     plt.scatter(x.data.numpy(), y.data.numpy())
    77     plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    78     plt.show()
    79 
    80 # save net1
    81 save()
    82 
    83 # restore entire net (may slow)
    84 restore_net()
    85 
    86 # restore only the net parameters
    87 restore_params()
  • 相关阅读:
    Working with macro signatures
    Reset and Clear Recent Items and Frequent Places in Windows 10
    git分支演示
    The current .NET SDK does not support targeting .NET Core 2.1. Either target .NET Core 2.0 or lower, or use a version of the .NET SDK that supports .NET Core 2.1.
    Build website project by roslyn through devenv.com
    Configure environment variables for different tools in jenkins
    NUnit Console Command Line
    Code Coverage and Unit Test in SonarQube
    头脑王者 物理化学生物
    头脑王者 常识,饮食
  • 原文地址:https://www.cnblogs.com/xuechengmeigui/p/12388514.html
Copyright © 2011-2022 走看看