zoukankan      html  css  js  c++  java
  • 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别

    刚开始学习深度学习,先来写一个 MNIST 手写数字识别的程序。

    import numpy as np
    import os
    import codecs
    import torch
    from PIL import Image
    
    # Hyper-parameters for the training run.
    # Learning rate handed to optim.SGD.
    lr = 0.01
    # Momentum factor handed to optim.SGD.
    momentum = 0.5
    # Number of train-then-test passes executed by main().
    epochs = 10
    
    def get_int(b):
        """Decode a big-endian unsigned integer from a bytes-like object.

        Args:
            b: The raw bytes to interpret, most significant byte first.

        Returns:
            The decoded non-negative ``int``. An empty input decodes to ``0``.
        """
        # int.from_bytes converts directly, without building an intermediate
        # hex string through the codecs module.
        return int.from_bytes(b, byteorder='big')
    
    def read_label_file(path):
        """Load a label file into a 1-D int64 tensor.

        The file must begin with a 4-byte big-endian magic number equal to
        2049, followed by a 4-byte big-endian item count. Every byte after
        that 8-byte header is read as one unsigned 8-bit label.

        Args:
            path: Location of the label file on disk.

        Returns:
            A ``torch`` int64 tensor of shape ``(length,)``.

        Raises:
            ValueError: If the header is truncated, the magic number is not
                2049, or the number of label bytes differs from the count
                stored in the header.
        """
        with open(path, 'rb') as f:
            data = f.read()

        # Explicit checks instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would let a wrong file through silently.
        if len(data) < 8:
            raise ValueError(
                'label file %r is too short to contain a header' % (path,))
        magic = get_int(data[:4])
        if magic != 2049:
            raise ValueError(
                'label file %r has magic number %d, expected 2049' % (path, magic))

        length = get_int(data[4:8])
        parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
        if parsed.size != length:
            raise ValueError(
                'label file %r declares %d labels but contains %d'
                % (path, length, parsed.size))

        # np.frombuffer over an immutable bytes object yields a read-only
        # array, which torch.from_numpy does not support cleanly. astype()
        # produces a writable int64 copy, so no later .long() is needed.
        return torch.from_numpy(parsed.astype(np.int64))
    
    def read_image_file(path):
        """Load an image file into a 3-D uint8 tensor.

        The file must begin with four 4-byte big-endian integers: a magic
        number equal to 2051, the image count, the row count and the column
        count. Every byte after that 16-byte header is one unsigned 8-bit
        pixel.

        Args:
            path: Location of the image file on disk.

        Returns:
            A ``torch`` uint8 tensor of shape ``(length, num_rows, num_cols)``.

        Raises:
            ValueError: If the header is truncated, the magic number is not
                2051, or the number of pixel bytes does not equal
                ``length * num_rows * num_cols``.
        """
        with open(path, 'rb') as f:
            data = f.read()

        # Explicit checks instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would let a wrong file through silently.
        if len(data) < 16:
            raise ValueError(
                'image file %r is too short to contain a header' % (path,))
        magic = get_int(data[:4])
        if magic != 2051:
            raise ValueError(
                'image file %r has magic number %d, expected 2051' % (path, magic))

        length = get_int(data[4:8])
        num_rows = get_int(data[8:12])
        num_cols = get_int(data[12:16])

        parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
        expected = length * num_rows * num_cols
        if parsed.size != expected:
            raise ValueError(
                'image file %r declares %d pixels but contains %d'
                % (path, expected, parsed.size))

        # np.frombuffer over an immutable bytes object yields a read-only
        # array; copy it so the resulting tensor owns writable memory.
        return torch.from_numpy(parsed.copy()).view(length, num_rows, num_cols)
    
    def loadmnist(path, kind='train'):
        """Read one split of the dataset from ``<path>/mnist``.

        Args:
            path: Directory that contains the ``mnist`` sub-directory.
            kind: File-name prefix selecting the split; the files read are
                ``<kind>-labels.idx1-ubyte`` and ``<kind>-images.idx3-ubyte``.

        Returns:
            A ``(images, labels)`` tuple as produced by ``read_image_file``
            and ``read_label_file``.
        """
        data_dir = os.path.join(path, 'mnist')

        # Labels are read before images, matching the original load order.
        label_tensor = read_label_file(
            os.path.join(data_dir, '%s-labels.idx1-ubyte' % kind))
        image_tensor = read_image_file(
            os.path.join(data_dir, '%s-images.idx3-ubyte' % kind))

        return image_tensor, label_tensor
    
    import torch.utils.data as data
    import torchvision.transforms as transforms
    
    class Loader(data.Dataset):
        """Dataset that serves (image, label) pairs loaded by ``loadmnist``.

        Attributes are populated once in ``__init__``; each item is converted
        to a PIL image on access and optionally passed through ``transforms``.
        """

        def __init__(self, root, label, transforms):
            """Load every image and label of one split into memory.

            Args:
                root: Directory forwarded to ``loadmnist`` as ``path``.
                label: Split name forwarded to ``loadmnist`` as ``kind``.
                transforms: Callable applied to each PIL image, or a falsy
                    value to return the PIL image unchanged.
            """
            self.imgs, self.labels = loadmnist(root, label)
            self.transforms = transforms

        def __getitem__(self, index):
            """Return the ``(image, label)`` pair stored at ``index``."""
            raw, target = self.imgs[index], self.labels[index]

            # Transforms operate on PIL images, so wrap the tensor's data as
            # a single-channel 8-bit ('L') image first.
            picture = Image.fromarray(raw.numpy(), mode='L')
            if not self.transforms:
                return picture, target
            return self.transforms(picture), target

        def __len__(self):
            """Return the number of images held by the dataset."""
            return len(self.imgs)
    
    def getTrainDataset(root='d:\\work\\yoho\\dl\\dl-study\\chapter0\\'):
        """Build the training-split dataset.

        Args:
            root: Directory containing the ``mnist`` sub-directory. Defaults
                to the author's original location. The backslashes are
                doubled because a single trailing backslash would escape the
                closing quote and make the literal a syntax error.

        Returns:
            A ``Loader`` over the ``'train'`` split whose images are
            converted to tensors and then normalized.
        """
        return Loader(root, 'train', transforms.Compose([
            transforms.ToTensor(),
            # Presumably the dataset's mean and standard deviation -- confirm.
            transforms.Normalize((0.1307,), (0.3081,)),
        ]))
    
    def getTestDataset(root='d:\\work\\yoho\\dl\\dl-study\\chapter0\\'):
        """Build the test-split dataset.

        Args:
            root: Directory containing the ``mnist`` sub-directory. Defaults
                to the author's original location. The backslashes are
                doubled because a single trailing backslash would escape the
                closing quote and make the literal a syntax error.

        Returns:
            A ``Loader`` over the ``'t10k'`` split whose images are
            converted to tensors and then normalized.
        """
        return Loader(root, 't10k', transforms.Compose([
            transforms.ToTensor(),
            # Presumably the dataset's mean and standard deviation -- confirm.
            transforms.Normalize((0.1307,), (0.3081,)),
        ]))
    
    import torch as t
    import torch.nn as nn
    
    class Net(nn.Module):
        """Small convolutional classifier producing 10 log-probabilities.

        ``features`` holds two convolution / max-pool / ReLU stages and
        ``classifier`` holds two linear layers ending in ``LogSoftmax``.
        The first linear layer expects 320 inputs per sample, which is what
        a 1x28x28 input yields after the feature stages (20 x 4 x 4).
        """

        def __init__(self):
            super(Net, self).__init__()

            # Layers are created in the same order as they are applied, so
            # parameter names and initialization order stay stable.
            conv_stages = [
                nn.Conv2d(1, 10, kernel_size=5),
                nn.MaxPool2d(2),
                nn.ReLU(inplace=True),
                nn.Conv2d(10, 20, kernel_size=5),
                nn.Dropout2d(),
                nn.MaxPool2d(2),
                nn.ReLU(inplace=True),
            ]
            self.features = nn.Sequential(*conv_stages)

            dense_stages = [
                nn.Linear(320, 50),
                nn.ReLU(inplace=True),
                nn.Dropout(),
                nn.Linear(50, 10),
                nn.LogSoftmax(dim=1),
            ]
            self.classifier = nn.Sequential(*dense_stages)

        def forward(self, x):
            """Map a batch of images to per-class log-probabilities."""
            feature_maps = self.features(x)
            # Collapse everything except the batch dimension.
            flattened = feature_maps.view(feature_maps.size(0), -1)
            return self.classifier(flattened)
    
    # Module-level training setup: the names below are read as globals by
    # train() and test().
    net = Net()
    
    import torch.optim as optim
    from torch.nn.modules import loss
    
    # SGD over all network parameters, using the lr / momentum constants.
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    # NOTE(review): Net already ends in LogSoftmax, while CrossEntropyLoss is
    # documented to apply log-softmax itself; NLLLoss looks like the
    # conventional pairing here -- confirm before changing.
    criterion = loss.CrossEntropyLoss()
    
    # Both splits are read from disk when these calls run.
    train_dataset = getTrainDataset()
    test_dataset = getTestDataset()
    
    # Batches of 4 samples; only the training loader shuffles.
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False)
    
    from torch.autograd import Variable as V
    
    def train(epoch):
        """Run one pass over ``train_loader``, stepping the optimizer per batch.

        Args:
            epoch: Index of the current epoch. It is not used in the body.
        """
        for batch_inputs, batch_labels in train_loader:
            predictions = net(V(batch_inputs))
            batch_loss = criterion(predictions, V(batch_labels))

            # Clear gradients left from the previous batch before
            # back-propagating this batch's loss.
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            
    def test(epoch):
        """Evaluate ``net`` on ``test_loader``, printing accuracy per batch.

        For every batch one line is printed containing the epoch, the batch
        index, and the top-1 and top-5 accuracy of that batch in percent.

        Args:
            epoch: Index of the current epoch; only used in the printed line.
        """
        max_k = 5

        # Evaluation must run with dropout disabled. The previous mode is
        # restored afterwards so that later training epochs are unaffected.
        was_training = net.training
        net.eval()
        try:
            for i, (inputs, labels) in enumerate(test_loader):
                outputs = net(V(inputs))

                # Indices of the max_k highest-scoring classes per sample.
                _, pred = outputs.data.topk(max_k, 1, True, True)

                batch_size = labels.size(0)
                # Transpose so that row r holds every sample's rank-r guess.
                pred = pred.t()
                correct = pred.eq(labels.view(1, -1).expand_as(pred))

                res = []
                for k in (1, max_k):
                    # ``correct`` derives from a transposed tensor and may be
                    # non-contiguous, in which case view(-1) raises; make the
                    # slice contiguous before flattening it.
                    correct_k = (correct[:k].contiguous().view(-1)
                                 .float().sum(0, keepdim=True))
                    res.append(correct_k.mul_(100.0 / batch_size))
                print('{} {} top1 {} top5 {}'.format(epoch, i, res[0][0], res[1][0]))
        finally:
            net.train(was_training)
    
    def main():
        """Alternate training and evaluation for ``epochs`` epochs."""
        for epoch_index in range(epochs):
            # Evaluate after every training pass.
            train(epoch_index)
            test(epoch_index)
    
    # Script entry point: starts the train/test loop whenever this module is
    # executed (there is no __name__ guard, so importing it runs it too).
    main()
    

    训练完成之后,正确率很高;这类问题对于深度学习来说已经解决了。

  • 相关阅读:
    HDU2602:Bone Collector
    HDU5773:The All-purpose Zero
    LightOJ 1275:Internet Service Providers
    8.SpringMVC拦截器
    7.SpringMVC和Ajax技术
    Tomcat什么时候需要restart,redeploy,update classes and resources
    6.SpringMVC的JSON讲解
    5.SpringMVC数据处理
    4.SpringMVC的结果跳转方式
    3.SpringMVC的Controller 及 RestFul风格
  • 原文地址:https://www.cnblogs.com/htoooth/p/8651259.html
Copyright © 2011-2022 走看看