  • PyTorch: optimizer comparison

    The script below fits a small one-hidden-layer network to y = x^2 + noise and compares four optimizers (SGD, SGD with momentum, RMSprop, and Adam) by recording and plotting their training-loss curves.

    import torch
    import torch.utils.data as Data
    import torch.nn.functional as F
    import torch.optim
    import matplotlib.pyplot as plt

    # torch.manual_seed(1)    # uncomment for reproducible results

    LR = 0.01
    BATCH_SIZE = 32
    EPOCH = 12

    # fake dataset: y = x^2 plus Gaussian noise
    x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
    y = x.pow(2) + 0.1 * torch.normal(torch.zeros(*x.size()))

    # plot the dataset
    plt.scatter(x.numpy(), y.numpy())
    plt.show()

    # put the dataset into a torch DataLoader
    torch_dataset = Data.TensorDataset(x, y)
    loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)


    # default network
    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.hidden = torch.nn.Linear(1, 20)    # hidden layer
            self.predict = torch.nn.Linear(20, 1)   # output layer

        def forward(self, x):
            x = F.relu(self.hidden(x))      # activation function for the hidden layer
            x = self.predict(x)             # linear output
            return x


    if __name__ == '__main__':
        # different nets, one per optimizer, so each trains independently
        net_SGD      = Net()
        net_Momentum = Net()
        net_RMSprop  = Net()
        net_Adam     = Net()
        nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]

        # different optimizers
        opt_SGD      = torch.optim.SGD(net_SGD.parameters(), lr=LR)
        opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)
        opt_RMSprop  = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)
        opt_Adam     = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
        optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]

        loss_func = torch.nn.MSELoss()
        losses_his = [[], [], [], []]   # one loss history per net

        # training
        for epoch in range(EPOCH):
            print('Epoch: ', epoch)
            for step, (b_x, b_y) in enumerate(loader):          # for each training step
                for net, opt, l_his in zip(nets, optimizers, losses_his):
                    output = net(b_x)              # get the output of every net
                    loss = loss_func(output, b_y)  # compute the loss of every net
                    opt.zero_grad()                # clear gradients before the backward pass
                    loss.backward()                # backpropagation, compute gradients
                    opt.step()                     # apply gradients
                    l_his.append(loss.item())      # record the scalar loss

        labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
        for i, l_his in enumerate(losses_his):
            plt.plot(l_his, label=labels[i])
        plt.legend(loc='best')
        plt.xlabel('Steps')
        plt.ylabel('Loss')
        plt.ylim((0, 0.2))
        plt.show()
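    The same harness extends to any optimizer in torch.optim. A minimal sketch (not part of the original post; it assumes PyTorch 1.2+ for AdamW, and the net_AdamW / opt_AdamW names are illustrative): add one more network, optimizer, loss list, and plot label, and the existing training and plotting loops pick it up unchanged.

    # Sketch: a fifth entry in the comparison. Extend nets/optimizers/losses_his
    # before the training loop runs, and add the matching label before plotting.
    net_AdamW = Net()
    opt_AdamW = torch.optim.AdamW(net_AdamW.parameters(), lr=LR, weight_decay=0.01)
    nets.append(net_AdamW)
    optimizers.append(opt_AdamW)
    losses_his.append([])
    # ... after training, extend the legend to match:
    labels.append('AdamW')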
  • Original post: https://www.cnblogs.com/dhName/p/11743220.html