zoukankan      html  css  js  c++  java
  • pytorch实现mnist识别实战

    函数

    画图、绘制曲线的函数:

    import torch
    from    matplotlib import pyplot as plt
    
    
    def plot_curve(data):
        fig = plt.figure()
        plt.plot(range(len(data)), data, color='blue')
        plt.legend(['value'], loc='upper right')
        plt.xlabel('step')
        plt.ylabel('value')
        plt.show()
    
    
    def plot_image(img, label, name):
        fig = plt.figure()
        for i in range(6):
            plt.subplot(2, 3, i + 1)
            plt.tight_layout()
            plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
            plt.title("{}: {}".format(name, label[i].item()))
            plt.xticks([])
            plt.yticks([])
        plt.show()
    

    第一步:加载数据集

    batch_size=512
    
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('mnist_data',train=True,download=True,
                                   transform=torchvision.transforms.Compose(
                                       [
                                           torchvision.transforms.ToTensor(),
                                           torchvision.transforms.Normalize((0.1307,),(0.3081,))#这里的两个数字分别是数据集的均值是0.1307,标准差是0.3081
                                       ]
                                   )
                                   ),
        batch_size=batch_size,shuffle=True
    )
    
    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('mnist_data/',train=False,download=True,#是验证集所以train=False
                                   transform=torchvision.transforms.Compose(
                                       [
                                           torchvision.transforms.ToTensor(),
                                           torchvision.transforms.Normalize((0.1307,),(0.3081,))
                                       ]
                                   )
                                   ),
        batch_size=batch_size,shuffle=False#是验证集所以无需打乱,shuffle=False
    )
    
    

    第二步:创建网络模型

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
    
            #wx+b
            self.fc1 = nn.Linear(28*28,256)#256是自己根据经验随机设定的
            self.fc2 = nn.Linear(256,64)
            self.fc3 = nn.Linear(64,10)#注意这里的10是最后识别的类别数(最后一层的输出往往是识别的类别数)
    
        def forward(self, x):
            #x : [ b 1 28 28]有batch_size张图片,通道是1维灰度图像 图片大小是28*28
    
            #h1=relu(wx+b)
            x = F.relu(self.fc1(x))#使用relu非线性激活函数包裹
            x = F.relu(self.fc2(x))
            x = F.softmax(self.fc3(x))#由于是多类别识别,所以使用softmax函数
            #x = self.fc3(x)
            return x
    
    

    第三步:训练

    net = Net()
    optimizer = optim.SGD(net.parameters(),lr=0.1,momentum=0.9)
    train_loss = []
    
    
    for epoch in range(5):
        for batch_idx,(x,y) in enumerate(train_loader):#enumerate表示在数据前面加上序号组成元组,默认序号从0开始
    
            # x :[512 1 28 28]   y : [512]
    
            #由于这里的x维度为[512 1 28 28],但是在网络中第一层就是一个全连接层,维度只能是[b,feature(784)],所以要把x打平
            #将前面多维度的tensor展平成一维
    
            # 卷积或者池化之后的tensor的维度为(batchsize,channels,x,y),其中x.size(0)
            # 指batchsize的值,最后通过x.view(x.size(0), -1)
            # 将tensor的结构转换为了(batchsize, channels * x * y),即将(channels,x,y)拉直,然后就可以和fc层连接了
    
            x = x.view(x.size(0),28*28)
            #输出之后的维度变为[512,10]
            out=net(x)
            #使用交叉熵损失
            loss = F.cross_entropy(out,y)
    
            #清零梯度——计算梯度——更新梯度
    
            #要进行梯度的清零
            optimizer.zero_grad()
    
            loss.backward()
            #功能是: w` = w-lr*grad
            optimizer.step()
    
            train_loss.append(loss.item())#将loss保存在trainloss中,而loss.item()表示将tensor 的类型转换为数值类型
    
            #打印loss
            if batch_idx % 10 == 0:
                print(epoch,batch_idx,loss.item())
    
    

    第四步:验证

    plot_curve(train_loss)
    
    total_correct = 0
    for x, y in test_loader:
        x = x.view(x.size(0),28*28)
        out = net(x)
        #out :[512,10]
        pred = out.argmax(dim = 1)
        correct = pred.eq(y).sum().float().item()#当前批次识别对的个数
        total_correct+= correct
    
    total_number = len(test_loader.dataset)
    acc = total_correct / total_number
    print('test acc',acc)
    
    
    x,y = next(iter(test_loader))
    out = net(x.view(x.size(0),28*28))
    pred = out.argmax(dim=1)
    plot_image(x,pred,'test')
    

    训练中loss的下降趋势:
    在这里插入图片描述
    测试效果:
    在这里插入图片描述
    开始的learning rate设置的为0.01,最终的acc为0.86,将learning rate改为0.1后,acc为0.96
    在这里插入图片描述

  • 相关阅读:
    Ubuntu adb devices :???????????? no permissions (verify udev rules) 解决方法
    ubuntu 关闭显示器的命令
    ubuntu android studio kvm
    ubuntu 14.04版本更改文件夹背景色为草绿色
    ubuntu 创建桌面快捷方式
    Ubuntu 如何更改用户密码
    ubuntu 14.04 返回到经典桌面方法
    ubuntu 信使(iptux) 创建桌面快捷方式
    Eclipse failed to get the required ADT version number from the sdk
    Eclipse '<>' operator is not allowed for source level below 1.7
  • 原文地址:https://www.cnblogs.com/Jason66661010/p/13592027.html
Copyright © 2011-2022 走看看