1 """ 2 torch: 0.4 3 matplotlib 4 神经网络的保存 5 神经网络提取的2 ways 6 """ 7 import torch 8 import matplotlib.pyplot as plt 9 10 11 12 # fake data 13 x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1) 14 y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1) 15 16 # The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors 17 # x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) 18 19 # save net1 20 def save(): 21 # 建立网络实例net1 22 net1 = torch.nn.Sequential( 23 torch.nn.Linear(1, 10), 24 torch.nn.ReLU(), 25 torch.nn.Linear(10, 1) 26 ) 27 optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) 28 loss_func = torch.nn.MSELoss() 29 #训练 30 for t in range(100): 31 prediction = net1(x) 32 loss = loss_func(prediction, y) 33 optimizer.zero_grad() 34 loss.backward() 35 optimizer.step() 36 37 # plot result 38 plt.figure(1, figsize=(10, 3)) 39 plt.subplot(131) 40 plt.title('Net1') 41 plt.scatter(x.data.numpy(), y.data.numpy()) 42 plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 43 44 # 2 ways to save the net 45 torch.save(net1, 'net.pkl') # save entire net 46 torch.save(net1.state_dict(), 'net_params.pkl') # save only the parameters 47 48 49 def restore_net(): 50 # restore entire net1 to net2 51 net2 = torch.load('net.pkl') 52 prediction = net2(x) 53 54 # plot result 55 plt.subplot(132) 56 plt.title('Net2') 57 plt.scatter(x.data.numpy(), y.data.numpy()) 58 plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 59 60 61 def restore_params(): 62 # restore only the parameters in net1 to net3 63 net3 = torch.nn.Sequential( 64 torch.nn.Linear(1, 10), 65 torch.nn.ReLU(), 66 torch.nn.Linear(10, 1) 67 ) 68 69 # copy net1's parameters into net3 70 net3.load_state_dict(torch.load('net_params.pkl')) 71 prediction = net3(x) 72 73 # plot result 74 plt.subplot(133) 75 plt.title('Net3') 76 plt.scatter(x.data.numpy(), y.data.numpy()) 77 plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 78 plt.show() 79 80 # save net1 81 save() 82 83 # restore entire net (may slow) 84 restore_net() 85 86 # restore only the net parameters 87 restore_params()