zoukankan      html  css  js  c++  java
  • pytorch(二十一):交叉验证

    一、K折交叉验证

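    将训练集分成K份,每次取其中一份做验证集,其余K-1份做训练集;轮流进行K次,使这K份数据都有机会做验证集,最后取K次验证结果的平均值来评估模型。下面的代码只演示了其中一次划分:用 random_split 从60000张训练图片中随机划出10000张作为验证集,其余50000张用于训练。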

     

     

    二、代码

      1 import torch
      2 import torch.nn as nn
      3 import torchvision 
      4 from torchvision import datasets,transforms
      5 from torch.nn import functional as F
      6 import torch.optim as optim
      7 
      8 
      9 batch_size = 200
     10 learning_rate  = 1e-2
     11 epochs = 10
     12 train_db =  datasets.MNIST('datasets/mnist_data',
     13                 train=True,
     14                 download=True,
     15                 transform=torchvision.transforms.Compose([
     16                 torchvision.transforms.ToTensor(),                       # 数据类型转化
     17                 torchvision.transforms.Normalize((0.1307, ), (0.3081, )) # 数据归一化处理
     18     ]))
     19 
     20 train_loader = torch.utils.data.DataLoader(
     21         train_db,
     22         batch_size = batch_size,
     23         shuffle = True)
     24 
     25 test_db = datasets.MNIST('datasets/mnist_data/',
     26                 train=False,
     27                 download=True,
     28                 transform=torchvision.transforms.Compose([
     29                 torchvision.transforms.ToTensor(),
     30                 torchvision.transforms.Normalize((0.1307, ), (0.3081, ))
     31     ]))
     32 
     33 test_loader = torch.utils.data.DataLoader(
     34         test_db,
     35         batch_size = batch_size,
     36         shuffle = True
     37 )
     38 
     39 print('train:', len(train_db), 'test:', len(test_db))
     40 train_db, val_db = torch.utils.data.random_split(train_db, [50000, 10000])
     41 print('db1:', len(train_db), 'db2:', len(val_db))
     42 train_loader = torch.utils.data.DataLoader(
     43     train_db,
     44     batch_size=batch_size, shuffle=True)
     45 val_loader = torch.utils.data.DataLoader(
     46     val_db,
     47     batch_size=batch_size, shuffle=True)
     48 
     49 class MLP(nn.Module):
     50 
     51     def __init__(self):
     52         super(MLP, self).__init__()
     53 
     54         self.model = nn.Sequential(
     55             nn.Linear(784, 200),
     56             nn.LeakyReLU(inplace=True),
     57             nn.Linear(200, 200),
     58             nn.LeakyReLU(inplace=True),
     59             nn.Linear(200, 10),
     60             nn.LeakyReLU(inplace=True),
     61         )
     62 
     63     def forward(self, x):
     64         x = self.model(x)
     65 
     66         return x
     67 
     68 device = torch.device('cuda:0')
     69 net = MLP().to(device)
     70 optimizer = optim.SGD(net.parameters(), lr=learning_rate)
     71 criteon = nn.CrossEntropyLoss().to(device)
     72 
     73 for epoch in range(epochs):
     74 
     75     for batch_idx, (data, target) in enumerate(train_loader):
     76         data = data.view(-1, 28*28)
     77         data, target = data.to(device), target.cuda()
     78 
     79         logits = net(data)
     80         loss = criteon(logits, target)
     81 
     82         optimizer.zero_grad()
     83         loss.backward()
     84         # print(w1.grad.norm(), w2.grad.norm())
     85         optimizer.step()
     86 
     87         if batch_idx % 100 == 0:
     88             print('Train Epoch: {} [{}/{} ({:.0f}%)]	Loss: {:.6f}'.format(
     89                 epoch, batch_idx * len(data), len(train_loader.dataset),
     90                        100. * batch_idx / len(train_loader), loss.item()))
     91 
     92 
     93     test_loss = 0
     94     correct = 0
     95     for data, target in val_loader:
     96         data = data.view(-1, 28 * 28)
     97         data, target = data.to(device), target.cuda()
     98         logits = net(data)
     99         test_loss += criteon(logits, target).item()
    100 
    101         pred = logits.data.max(1)[1]
    102         correct += pred.eq(target.data).sum()
    103 
    104     test_loss /= len(val_loader.dataset)
    105     print('
    VAL set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)
    '.format(
    106         test_loss, correct, len(val_loader.dataset),
    107         100. * correct / len(val_loader.dataset)))
    108 
    109 
    110 
    111 test_loss = 0
    112 correct = 0
    113 for data, target in test_loader:
    114     data = data.view(-1, 28 * 28)
    115     data, target = data.to(device), target.cuda()
    116     logits = net(data)
    117     test_loss += criteon(logits, target).item()
    118 
    119     pred = logits.data.max(1)[1]
    120     correct += pred.eq(target.data).sum()
    121 
    122 test_loss /= len(test_loader.dataset)
    123 print('
    Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)
    '.format(
    124     test_loss, correct, len(test_loader.dataset),
    125     100. * correct / len(test_loader.dataset)))
  • 相关阅读:
    仿照京东做的一个鼠标移上去的图片文字说明效果
    js 之 复制一段代码
    自己练习了一个弹出框
    用jq 做了一个排序
    做了一个类似天猫鼠标经过icon的动画,记录一下
    一行代码写一个轮播,想了好久,感觉这样可以。
    一个小例子,全选复选框
    仿照淘宝首页做的一个高度伪对齐demo
    《挑战程序设计竞赛》2.2 贪心法-区间 POJ2376 POJ1328 POJ3190
    《挑战程序设计竞赛》2.1 穷竭搜索 POJ2718 POJ3187 POJ3050 AOJ0525
  • 原文地址:https://www.cnblogs.com/zhangxianrong/p/14060699.html
Copyright © 2011-2022 走看看