zoukankan      html  css  js  c++  java
  • PyTorch（二十二）：正则化

    一、实例

     

     

    二、代码

      1 import  torch
      2 import  torch.nn as nn
      3 import  torch.nn.functional as F
      4 import  torch.optim as optim
      5 from    torchvision import datasets, transforms
      6 
      7 from visdom import Visdom
      8 
      9 batch_size=200
     10 learning_rate=0.01
     11 epochs=10
     12 
     13 train_loader = torch.utils.data.DataLoader(
     14     datasets.MNIST('datasets/mnist_data', train=True, download=True,
     15                    transform=transforms.Compose([
     16                        transforms.ToTensor(),
     17                        # transforms.Normalize((0.1307,), (0.3081,))
     18                    ])),
     19     batch_size=batch_size, shuffle=True)
     20 test_loader = torch.utils.data.DataLoader(
     21     datasets.MNIST('datasets/mnist_data/', train=False, transform=transforms.Compose([
     22         transforms.ToTensor(),
     23         # transforms.Normalize((0.1307,), (0.3081,))
     24     ])),
     25     batch_size=batch_size, shuffle=True)
     26 
     27 
     28 
     29 class MLP(nn.Module):
     30 
     31     def __init__(self):
     32         super(MLP, self).__init__()
     33 
     34         self.model = nn.Sequential(
     35             nn.Linear(784, 200),
     36             nn.LeakyReLU(inplace=True),
     37             nn.Linear(200, 200),
     38             nn.LeakyReLU(inplace=True),
     39             nn.Linear(200, 10),
     40             nn.LeakyReLU(inplace=True),
     41         )
     42 
     43     def forward(self, x):
     44         x = self.model(x)
     45 
     46         return x
     47 
     48 device = torch.device('cuda:0')
     49 net = MLP().to(device)
     50 optimizer = optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.01)
     51 criteon = nn.CrossEntropyLoss().to(device)
     52 
     53 viz = Visdom()
     54 
     55 viz.line([0.], [0.], win='train_loss', opts=dict(title='train loss'))
     56 viz.line([[0.0, 0.0]], [0.], win='test', opts=dict(title='test loss&acc.',
     57                                                    legend=['loss', 'acc.']))
     58 global_step = 0
     59 
     60 for epoch in range(epochs):
     61 
     62     for batch_idx, (data, target) in enumerate(train_loader):
     63         data = data.view(-1, 28*28)
     64         data, target = data.to(device), target.cuda()
     65 
     66         logits = net(data)
     67         loss = criteon(logits, target)
     68 
     69         optimizer.zero_grad()
     70         loss.backward()
     71         # print(w1.grad.norm(), w2.grad.norm())
     72         optimizer.step()
     73 
     74         global_step += 1
     75         viz.line([loss.item()], [global_step], win='train_loss', update='append')
     76 
     77         if batch_idx % 100 == 0:
     78             print('Train Epoch: {} [{}/{} ({:.0f}%)]	Loss: {:.6f}'.format(
     79                 epoch, batch_idx * len(data), len(train_loader.dataset),
     80                        100. * batch_idx / len(train_loader), loss.item()))
     81 
     82 
     83     test_loss = 0
     84     correct = 0
     85     for data, target in test_loader:
     86         data = data.view(-1, 28 * 28)
     87         data, target = data.to(device), target.cuda()
     88         logits = net(data)
     89         test_loss += criteon(logits, target).item()
     90 
     91         pred = logits.argmax(dim=1)
     92         correct += pred.eq(target).float().sum().item()
     93 
     94     viz.line([[test_loss, correct / len(test_loader.dataset)]],
     95              [global_step], win='test', update='append')
     96     viz.images(data.view(-1, 1, 28, 28), win='x')
     97     viz.text(str(pred.detach().cpu().numpy()), win='pred',
     98              opts=dict(title='pred'))
     99 
    100     test_loss /= len(test_loader.dataset)
    101     print('
    Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)
    '.format(
    102         test_loss, correct, len(test_loader.dataset),
    103         100. * correct / len(test_loader.dataset)))
  • 相关阅读:
    将PHP文件生成静态文件源码
    Entity Framework Code First 学习日记(6)一对多关系
    Entity Framework Code First 学习日记(5)
    Entity Framework Code First 学习日记(3)
    Entity Framework Code First 学习日记(7)多对多关系
    Entity Framework Code First 学习日记(2)
    Entity Framework Code First 学习日记(8)一对一关系
    Entity Framework Code First 学习日记(9)映射继承关系
    Entity Framework Code First 学习日记(10)兼容遗留数据库
    Entity Framework Code First 学习日记(4)
  • 原文地址:https://www.cnblogs.com/zhangxianrong/p/14060728.html
Copyright © 2011-2022 走看看