zoukankan      html  css  js  c++  java
  • 莫烦pytorch学习笔记(七)——Optimizer优化器

    各种优化器的比较 

     莫烦对各种优化器的通俗讲解视频

      1 import torch
      2 
      3 import torch.utils.data as Data
      4 
      5 import torch.nn.functional as F
      6 
      7 from torch.autograd import Variable
      8 
      9 import matplotlib.pyplot as plt
     10 
     11  
     12 
     13 # 超参数
     14 
     15 LR = 0.01
     16 
     17 BATCH_SIZE = 32
     18 
     19 EPOCH = 12
     20 
     21  
     22 
     23 # 生成假数据
     24 
     25 # torch.unsqueeze() 的作用是将一维变二维,torch只能处理二维的数据
     26 
     27 x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)  # x data (tensor), shape(100, 1)
     28 
     29 # 0.2 * torch.rand(x.size())增加噪点
     30 
     31 y = x.pow(2) + 0.1 * torch.normal(torch.zeros(*x.size()))
     32 
     33  
     34 
     35 # 输出数据图
     36 
     37 # plt.scatter(x.numpy(), y.numpy())
     38 
     39 # plt.show()
     40 
     41  
     42 
     43 torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
     44 
     45 loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
     46 
     47  
     48 
     49  
     50 
     51 class Net(torch.nn.Module):
     52 
     53     # 初始化
     54 
     55     def __init__(self):
     56 
     57         super(Net, self).__init__()
     58 
     59         self.hidden = torch.nn.Linear(1, 20)
     60 
     61         self.predict = torch.nn.Linear(20, 1)
     62 
     63  
     64 
     65     # 前向传递
     66 
     67     def forward(self, x):
     68 
     69         x = F.relu(self.hidden(x))
     70 
     71         x = self.predict(x)
     72 
     73         return x
     74 
     75  
     76 
     77 net_SGD = Net()
     78 
     79 net_Momentum = Net()
     80 
     81 net_RMSProp = Net()
     82 
     83 net_Adam = Net()
     84 
     85  
     86 
     87 nets = [net_SGD, net_Momentum, net_RMSProp, net_Adam]
     88 
     89  
     90 
     91 opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)
     92 
     93 opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)
     94 
     95 opt_RMSProp = torch.optim.RMSprop(net_RMSProp.parameters(), lr=LR, alpha=0.9)
     96 
     97 opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
     98 
     99 optimizers = [opt_SGD, opt_Momentum, opt_RMSProp, opt_Adam]
    100 
    101  
    102 
    103 loss_func = torch.nn.MSELoss()
    104 
    105  
    106 
    107 loss_his = [[], [], [], []]  # 记录损失
    108 
    109  
    110 
    111 for epoch in range(EPOCH):
    112 
    113     print(epoch)
    114 
    115     for step, (batch_x, batch_y) in enumerate(loader):
    116 
    117         b_x = Variable(batch_x)
    118 
    119         b_y = Variable(batch_y)
    120 
    121  
    122 
    123         for net, opt,l_his in zip(nets, optimizers, loss_his):
    124 
    125             output = net(b_x)  # get output for every net
    126 
    127             loss = loss_func(output, b_y)  # compute loss for every net
    128 
    129             opt.zero_grad()  # clear gradients for next train
    130 
    131             loss.backward()  # backpropagation, compute gradients
    132 
    133             opt.step()  # apply gradients
    134 
    135             l_his.append(loss.data.numpy())  # loss recoder
    136 
    137 labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
    138 
    139 for i, l_his in enumerate(loss_his):
    140 
    141     plt.plot(l_his, label=labels[i])
    142 
    143 plt.legend(loc='best')
    144 
    145 plt.xlabel('Steps')
    146 
    147 plt.ylabel('Loss')
    148 
    149 plt.ylim((0, 0.2))
    150 
    151 plt.show()
    152 
    153         
    154 
    155  
    156 
    157  

  • 相关阅读:
    python3 进程间的通信(管道)Pipe
    python3 进程间的通信(队列)Queue
    python3 队列的简单用法Queue
    python3 进程锁Lock(模拟抢票)
    python3 守护进程daemon
    python3 僵尸进程
    python3 process中的name和pid
    python3 Process中的terminate和is_alive
    python3 通过多进程来实现一下同时和多个客户端进行连接通信
    python3 进程之间数据是隔离的
  • 原文地址:https://www.cnblogs.com/henuliulei/p/11397963.html
Copyright © 2011-2022 走看看