zoukankan      html  css  js  c++  java
  • 莫烦pytorch学习笔记(七)——Optimizer优化器

    各种优化器的比较 

     莫烦对各种优化器的通俗讲解视频

      1 import torch
      2 
      3 import torch.utils.data as Data
      4 
      5 import torch.nn.functional as F
      6 
      7 from torch.autograd import Variable
      8 
      9 import matplotlib.pyplot as plt
     10 
     11  
     12 
     13 # 超参数
     14 
     15 LR = 0.01
     16 
     17 BATCH_SIZE = 32
     18 
     19 EPOCH = 12
     20 
     21  
     22 
     23 # 生成假数据
     24 
     25 # torch.unsqueeze() 的作用是将一维变二维,torch只能处理二维的数据
     26 
     27 x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)  # x data (tensor), shape(100, 1)
     28 
     29 # 0.2 * torch.rand(x.size())增加噪点
     30 
     31 y = x.pow(2) + 0.1 * torch.normal(torch.zeros(*x.size()))
     32 
     33  
     34 
     35 # 输出数据图
     36 
     37 # plt.scatter(x.numpy(), y.numpy())
     38 
     39 # plt.show()
     40 
     41  
     42 
     43 torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
     44 
     45 loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
     46 
     47  
     48 
     49  
     50 
     51 class Net(torch.nn.Module):
     52 
     53     # 初始化
     54 
     55     def __init__(self):
     56 
     57         super(Net, self).__init__()
     58 
     59         self.hidden = torch.nn.Linear(1, 20)
     60 
     61         self.predict = torch.nn.Linear(20, 1)
     62 
     63  
     64 
     65     # 前向传递
     66 
     67     def forward(self, x):
     68 
     69         x = F.relu(self.hidden(x))
     70 
     71         x = self.predict(x)
     72 
     73         return x
     74 
     75  
     76 
     77 net_SGD = Net()
     78 
     79 net_Momentum = Net()
     80 
     81 net_RMSProp = Net()
     82 
     83 net_Adam = Net()
     84 
     85  
     86 
     87 nets = [net_SGD, net_Momentum, net_RMSProp, net_Adam]
     88 
     89  
     90 
     91 opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)
     92 
     93 opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)
     94 
     95 opt_RMSProp = torch.optim.RMSprop(net_RMSProp.parameters(), lr=LR, alpha=0.9)
     96 
     97 opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
     98 
     99 optimizers = [opt_SGD, opt_Momentum, opt_RMSProp, opt_Adam]
    100 
    101  
    102 
    103 loss_func = torch.nn.MSELoss()
    104 
    105  
    106 
    107 loss_his = [[], [], [], []]  # 记录损失
    108 
    109  
    110 
    111 for epoch in range(EPOCH):
    112 
    113     print(epoch)
    114 
    115     for step, (batch_x, batch_y) in enumerate(loader):
    116 
    117         b_x = Variable(batch_x)
    118 
    119         b_y = Variable(batch_y)
    120 
    121  
    122 
    123         for net, opt,l_his in zip(nets, optimizers, loss_his):
    124 
    125             output = net(b_x)  # get output for every net
    126 
    127             loss = loss_func(output, b_y)  # compute loss for every net
    128 
    129             opt.zero_grad()  # clear gradients for next train
    130 
    131             loss.backward()  # backpropagation, compute gradients
    132 
    133             opt.step()  # apply gradients
    134 
    135             l_his.append(loss.data.numpy())  # loss recoder
    136 
    137 labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
    138 
    139 for i, l_his in enumerate(loss_his):
    140 
    141     plt.plot(l_his, label=labels[i])
    142 
    143 plt.legend(loc='best')
    144 
    145 plt.xlabel('Steps')
    146 
    147 plt.ylabel('Loss')
    148 
    149 plt.ylim((0, 0.2))
    150 
    151 plt.show()
    152 
    153         
    154 
    155  
    156 
    157  

  • 相关阅读:
    flink 读取kafka 数据,partition分配
    Flink 报错 "Could not find a suitable table factory for 'org.apache.flink.table.factories.StreamTableSourceFactory' in the classpath"
    flume接收http请求,并将数据写到kafka
    【翻译】Flume 1.8.0 User Guide(用户指南) Processors
    【翻译】Flume 1.8.0 User Guide(用户指南) Channel
    【翻译】Flume 1.8.0 User Guide(用户指南) Sink
    【翻译】Flume 1.8.0 User Guide(用户指南) source
    【翻译】Flume 1.8.0 User Guide(用户指南)
    Apache Flink 简单安装
    Java之使用IDE
  • 原文地址:https://www.cnblogs.com/henuliulei/p/11397963.html
Copyright © 2011-2022 走看看