zoukankan      html  css  js  c++  java
  • PyTorch(二十二):正则化

    一、实例

     

     

    二、代码

      1 import  torch
      2 import  torch.nn as nn
      3 import  torch.nn.functional as F
      4 import  torch.optim as optim
      5 from    torchvision import datasets, transforms
      6 
      7 from visdom import Visdom
      8 
      9 batch_size=200
     10 learning_rate=0.01
     11 epochs=10
     12 
     13 train_loader = torch.utils.data.DataLoader(
     14     datasets.MNIST('datasets/mnist_data', train=True, download=True,
     15                    transform=transforms.Compose([
     16                        transforms.ToTensor(),
     17                        # transforms.Normalize((0.1307,), (0.3081,))
     18                    ])),
     19     batch_size=batch_size, shuffle=True)
     20 test_loader = torch.utils.data.DataLoader(
     21     datasets.MNIST('datasets/mnist_data/', train=False, transform=transforms.Compose([
     22         transforms.ToTensor(),
     23         # transforms.Normalize((0.1307,), (0.3081,))
     24     ])),
     25     batch_size=batch_size, shuffle=True)
     26 
     27 
     28 
     29 class MLP(nn.Module):
     30 
     31     def __init__(self):
     32         super(MLP, self).__init__()
     33 
     34         self.model = nn.Sequential(
     35             nn.Linear(784, 200),
     36             nn.LeakyReLU(inplace=True),
     37             nn.Linear(200, 200),
     38             nn.LeakyReLU(inplace=True),
     39             nn.Linear(200, 10),
     40             nn.LeakyReLU(inplace=True),
     41         )
     42 
     43     def forward(self, x):
     44         x = self.model(x)
     45 
     46         return x
     47 
     48 device = torch.device('cuda:0')
     49 net = MLP().to(device)
     50 optimizer = optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.01)
     51 criteon = nn.CrossEntropyLoss().to(device)
     52 
     53 viz = Visdom()
     54 
     55 viz.line([0.], [0.], win='train_loss', opts=dict(title='train loss'))
     56 viz.line([[0.0, 0.0]], [0.], win='test', opts=dict(title='test loss&acc.',
     57                                                    legend=['loss', 'acc.']))
     58 global_step = 0
     59 
     60 for epoch in range(epochs):
     61 
     62     for batch_idx, (data, target) in enumerate(train_loader):
     63         data = data.view(-1, 28*28)
     64         data, target = data.to(device), target.cuda()
     65 
     66         logits = net(data)
     67         loss = criteon(logits, target)
     68 
     69         optimizer.zero_grad()
     70         loss.backward()
     71         # print(w1.grad.norm(), w2.grad.norm())
     72         optimizer.step()
     73 
     74         global_step += 1
     75         viz.line([loss.item()], [global_step], win='train_loss', update='append')
     76 
     77         if batch_idx % 100 == 0:
     78             print('Train Epoch: {} [{}/{} ({:.0f}%)]	Loss: {:.6f}'.format(
     79                 epoch, batch_idx * len(data), len(train_loader.dataset),
     80                        100. * batch_idx / len(train_loader), loss.item()))
     81 
     82 
     83     test_loss = 0
     84     correct = 0
     85     for data, target in test_loader:
     86         data = data.view(-1, 28 * 28)
     87         data, target = data.to(device), target.cuda()
     88         logits = net(data)
     89         test_loss += criteon(logits, target).item()
     90 
     91         pred = logits.argmax(dim=1)
     92         correct += pred.eq(target).float().sum().item()
     93 
     94     viz.line([[test_loss, correct / len(test_loader.dataset)]],
     95              [global_step], win='test', update='append')
     96     viz.images(data.view(-1, 1, 28, 28), win='x')
     97     viz.text(str(pred.detach().cpu().numpy()), win='pred',
     98              opts=dict(title='pred'))
     99 
    100     test_loss /= len(test_loader.dataset)
    101     print('
    Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)
    '.format(
    102         test_loss, correct, len(test_loader.dataset),
    103         100. * correct / len(test_loader.dataset)))
  • 相关阅读:
    读后感
    每日总结
    融e学 一个专注于重构知识,培养复合型人才的平台【获取考试答案_破解】
    营销经验总结:如何才能提升h5游戏代入感?
    在通过《令人心动的offer》中,掌握学到的公司选择企业邮箱3大技巧!
    2020年的最后一个月来了!TOM 企业邮箱陪伴您度过!
    用户主动分享h5游戏的4大理由
    企业邮箱的作用是什么?企业邮箱有什么用处?
    电子邮箱系统哪家好?邮箱登陆入口是?
    专业外贸企业邮箱,企业邮箱退信怎么办?
  • 原文地址:https://www.cnblogs.com/zhangxianrong/p/14060728.html
Copyright © 2011-2022 走看看