zoukankan      html  css  js  c++  java
  • 1027-pytorch之手写体识别

    pytorch手写体识别

    代码

    import torch
    from torch import nn
    from torch.nn import functional as F
    from torch import optim
    
    import torchvision
    from matplotlib import pyplot as plt
    
    from torch_study.lesson5_minist_train.utils import plot_curve, plot_image, plt, one_hot
    
    batch_size = 512
    
    # step1. load dataset
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                                   transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize(
                                           (0.1307,), (0.3081,))
                                   ])),
        batch_size=batch_size, shuffle=True)
    
    #batch_size为一次训练多少,shuffle是否打散
    
    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                                   transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize(
                                           (0.1307,), (0.3081,))
                                   ])),
        batch_size=batch_size, shuffle=False)
    
    #查看数据维度
    x,y = next(iter(train_loader))
    print(x.shape,y.shape,x.min(),x.max())
    plot_image(x,y,'image sample')
    
    
    class Net(nn.Module):
    
        def __init__(self):
            super(Net,self).__init__()
    
            #wx+b
            self.fc1 = nn.Linear(28*28,256)
            self.fc2 = nn.Linear(256,64)
            self.fc3 = nn.Linear(64,10)
    
        def forward(self,x):
            # x:[b,1,28,28]
            # h1=relu(xw1+b1)
            x=F.relu(self.fc1(x))
            # h2=relu(h1*w2+b2)
            x=F.relu(self.fc2(x))
            # h3=h2*w3+b3
            x=self.fc3(x)
    
            return x
    
    
    net = Net()
    # [w1,b1,w2,b2,w3,b3]  momentum动量
    optimizer = optim.SGD(net.parameters(),lr=0.05,momentum=0.9)
    
    train_loss = []
    
    #对数据集迭代3次
    for epoch in range(3):
        #从数据集中sample出一个batch_size图片
        for batch_idx ,(x,y) in enumerate(train_loader):
    
            #x:[b,1,28,28] ,y[512]
            #[b,1,28,28] => [b,feature]
            x=x.view(x.size(0),28*28)
            # => [b,10]
            out = net(x)
            #[b,10]
            y_onehot = one_hot(y)
            #loss = mse(out,y_onehot)
            loss = F.cross_entropy(out,y_onehot)
            #清零梯度
            optimizer.zero_grad()
            #计算梯度
            loss.backward()
            #w'=w-lr*grad,更新梯度
            optimizer.step()
    
            train_loss.append(loss.item())
    
            if batch_idx %10 ==0:
                print(epoch,batch_idx,loss.item())
    
    #绘制损失曲线
    plot_curve(train_loss)
    # we get optimal [w1,b1,w2,b2,w3,b3]
    
    #对测试集进行判断
    total_corrrect=0
    for x,y in test_loader:
        x=x.view(x.size(0),28*28)
        out=net(x)
        # out:[b,10] => pred: [b]
        pred = out.argmax(dim=1)
        correct = pred.eq(y).sum().float().item()
        total_corrrect+=correct
    
    total_num = len(test_loader.dataset)
    acc = total_corrrect / total_num
    print('test acc:',acc)
    
    x,y=next(iter(test_loader))
    
    out = net(x.view(x.size(0),28*28))
    pred = out.argmax(dim=1)
    plot_image(x,pred,'test')

    结果

     

     

     

     模型提升

    增加模型层数

    调整loss损失计算函数

    调整学习率,训练大小batch_size

  • 相关阅读:
    健康检查详解:机制、配置、对比、实操
    制作自签名证书
    常用的UML建模
    UML建模更好的表达产品逻辑
    常用的UML建模
    UML建模图实战笔记
    领域驱动设计学习之路—DDD的原则与实践
    DDD领域驱动设计理论篇
    WAN、LAN、WLAN三种网口的区别
    新生代Eden与两个Survivor区的解释
  • 原文地址:https://www.cnblogs.com/xiaofengzai/p/15473485.html
Copyright © 2011-2022 走看看