zoukankan      html  css  js  c++  java
  • pytorch进行mnist识别实战

    mnist实战

    首先使用简单的全连接网络进行 MNIST 手写数字识别，测试集识别率最高约 95%；而先用两层卷积、再接全连接层，识别率能达到 99%。

    全连接:

    import torch
    from torch import nn
    from torch.nn import functional as F
    from torch import optim
    import torchvision
    from    matplotlib import pyplot as plt
    from torch.optim.lr_scheduler import StepLR
    
    #step 1:load dataset
    
    def plot_image(img, label, name):
        """Show the first six images of a batch in a 2x3 grid.

        Each panel is titled "<name>: <label>". ``img[i][0]`` (channel 0 of
        sample i) is un-normalized with mean 0.1307 / std 0.3081 before display.
        """
        plt.figure()
        for pos in range(1, 7):
            idx = pos - 1
            plt.subplot(2, 3, pos)
            plt.tight_layout()
            # Undo Normalize((0.1307,), (0.3081,)) so the digit shows in [0, 1] range.
            pixels = img[idx][0] * 0.3081 + 0.1307
            plt.imshow(pixels, cmap='gray', interpolation='none')
            plt.title(f"{name}: {label[idx].item()}")
            plt.xticks([])
            plt.yticks([])
        plt.show()
    
    def plot_curve(data):
        """Plot a sequence of values (e.g. per-batch training loss) against its step index."""
        plt.figure()
        steps = range(len(data))
        plt.plot(steps, data, color='blue')
        plt.legend(['value'], loc='upper right')
        plt.xlabel('step')
        plt.ylabel('value')
        plt.show()
    
    
    batch_size = 512

    # Shared preprocessing: tensor conversion, then normalization with the
    # MNIST mean (0.1307) and standard deviation (0.3081).
    mnist_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Training split, reshuffled every epoch.
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                                   transform=mnist_transform),
        batch_size=batch_size,
        shuffle=True,
    )

    # Test split (train=False); evaluation needs no shuffling.
    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                                   transform=mnist_transform),
        batch_size=batch_size,
        shuffle=False,
    )

    # Uncomment to preview one training batch:
    # x, y = next(iter(train_loader))
    # plot_image(x, y, 'example')
    
    
    #step2: create network
    
    
    class Net(nn.Module):
        """Three-layer fully connected classifier for flattened 28x28 MNIST images.

        ``forward`` expects a [b, 784] batch (images flattened from [b, 1, 28, 28])
        and returns [b, 10] unnormalized class scores (logits).
        """

        def __init__(self):
            super(Net, self).__init__()
            # Each layer computes w @ x + b.
            self.fc1 = nn.Linear(28 * 28, 256)  # hidden width 256 picked empirically
            self.fc2 = nn.Linear(256, 64)
            self.fc3 = nn.Linear(64, 10)  # one output per digit class

        def forward(self, x):
            """Map a [b, 784] batch to [b, 10] logits.

            No softmax here: ``F.cross_entropy`` already applies log-softmax, so a
            softmax in front of it squashes the gradients and floors the loss at
            ln(1 + 9/e) ~= 1.461. Use ``out.argmax(dim=1)`` for predictions, or
            ``F.softmax(out, dim=1)`` when probabilities are needed.
            """
            x = F.relu(self.fc1(x))  # h1 = relu(W1 x + b1)
            x = F.relu(self.fc2(x))
            return self.fc3(x)
    
    net = Net()
    optimizer = optim.Adam(net.parameters())
    train_loss = []  # loss of every batch, plotted once training finishes

    num_epochs = 5
    for epoch in range(num_epochs):
        for batch_idx, (images, labels) in enumerate(train_loader):
            # images: [b, 1, 28, 28], labels: [b].
            # fc1 expects flat [b, 784] feature vectors, so flatten each image.
            flat = images.view(images.size(0), 28 * 28)
            scores = net(flat)  # [b, 10]
            loss = F.cross_entropy(scores, labels)

            # Clear stale gradients, backpropagate, then let the optimizer update the weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            step_loss = loss.item()  # Python float, detached from the graph
            train_loss.append(step_loss)
            if batch_idx % 10 == 0:
                print(epoch, batch_idx, step_loss)

    plot_curve(train_loss)
    
    # Evaluate on the test split. This is inference only, so switch to eval mode
    # and disable autograd: no computation graph is built for the 10k test images.
    net.eval()
    total_correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.view(x.size(0), 28 * 28)
            out = net(x)  # [b, 10]
            pred = out.argmax(dim=1)
            total_correct += pred.eq(y).sum().item()  # correct predictions in this batch

    total_number = len(test_loader.dataset)
    acc = total_correct / total_number
    print('test acc', acc)

    # Show the predicted labels for the first test batch.
    x, y = next(iter(test_loader))
    with torch.no_grad():
        out = net(x.view(x.size(0), 28 * 28))
    pred = out.argmax(dim=1)
    plot_image(x, pred, 'test')

    # Recorded results for this fully connected network:
    # optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)  -> test acc 0.8783
    # optimizer = optim.Adam(net.parameters())                        -> test acc 0.9574
    

    加入卷积:

    import torch
    import argparse
    import torch.nn as nn
    import matplotlib.pyplot as plt
    import torch.optim as optim
    import torch.nn.functional as F
    from torchvision import datasets,transforms
    from torch.optim.lr_scheduler import StepLR
    
    class Net(nn.Module):
        """Two-convolution CNN for [b, 1, 28, 28] MNIST batches.

        ``forward`` returns [b, 10] unnormalized class scores (logits), which is
        what ``F.cross_entropy`` expects.
        """

        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 32, 3, 1)   # [b, 1, 28, 28] -> [b, 32, 26, 26]
            self.conv2 = nn.Conv2d(32, 64, 3, 1)  # -> [b, 64, 24, 24]
            # Channel-wise dropout on the 4-D pooled feature maps.
            self.dropout1 = nn.Dropout2d(0.25)
            # Applied to the flat [b, 128] activations, so plain element-wise Dropout.
            # Dropout2d is meant for (N, C, H, W) input; recent PyTorch warns on 2-D
            # input and falls back to element-wise dropout anyway.
            self.dropout2 = nn.Dropout(0.5)
            self.fc1 = nn.Linear(9216, 128)  # 64 * 12 * 12 = 9216 after max-pooling
            self.fc2 = nn.Linear(128, 10)

        def forward(self, x):
            """Map a [b, 1, 28, 28] batch to [b, 10] logits (no softmax; see class doc)."""
            x = F.relu(self.conv1(x))
            x = F.relu(self.conv2(x))  # [b, 64, 24, 24]
            x = F.max_pool2d(x, 2)     # [b, 64, 12, 12]
            x = self.dropout1(x)
            x = torch.flatten(x, 1)    # [b, 9216]
            x = F.relu(self.fc1(x))
            x = self.dropout2(x)
            # Raw logits: F.cross_entropy applies log-softmax itself; an extra softmax
            # here would floor the loss at ln(1 + 9/e) ~= 1.461.
            return self.fc2(x)

    # Shape check (10 random 1x28x28 images pushed through forward()):
    #   after conv2:       torch.Size([10, 64, 24, 24])
    #   after max_pool2d:  torch.Size([10, 64, 12, 12])
    #   after flatten:     torch.Size([10, 9216])   -> hence fc1 = Linear(9216, 128)
    
    def plot_image(img, label, name):
        """Display the first six images of a batch in a 2x3 grid titled "<name>: <label>".

        ``img[i][0]`` (channel 0 of sample i) is mapped back from the normalized
        space using mean 0.1307 and std 0.3081 before being drawn.
        """
        plt.figure()
        for pos in range(1, 7):
            idx = pos - 1
            plt.subplot(2, 3, pos)
            plt.tight_layout()
            # Reverse Normalize((0.1307,), (0.3081,)) for display.
            pixels = img[idx][0] * 0.3081 + 0.1307
            plt.imshow(pixels, cmap='gray', interpolation='none')
            plt.title(f"{name}: {label[idx].item()}")
            plt.xticks([])
            plt.yticks([])
        plt.show()
    
    def train(args, model, device, train_loader, optimizer, epoch):
        """Run one training epoch of ``model`` over ``train_loader``.

        Progress is printed every ``args.log_interval`` batches. When
        ``args.dry_run`` is set, the epoch stops right after the first logged batch.
        """
        model.train()  # training mode, so the dropout layers are active
        for step, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            loss = F.cross_entropy(model(inputs), labels)
            loss.backward()
            optimizer.step()

            if step % args.log_interval != 0:
                continue
            seen = step * len(inputs)
            progress = 100. * step / len(train_loader)
            print('train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, len(train_loader.dataset), progress, loss.item()))
            if args.dry_run:
                break
    
    def test(model, device, test_loader):
        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
                pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()
    
        test_loss /= len(test_loader.dataset)
    
        print('
    Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)
    '.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))
    
    
    def _parse_args():
        """Build the command-line parser and return the parsed training options."""
        parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
        parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                            help='input batch size for training (default: 64)')
        parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                            help='input batch size for testing (default: 1000)')
        parser.add_argument('--epochs', type=int, default=14, metavar='N',
                            help='number of epochs to train (default: 14)')
        parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                            help='learning rate (default: 0.1)')
        parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                            help='Learning rate step gamma (default: 0.7)')
        parser.add_argument('--no-cuda', action='store_true', default=False,
                            help='disables CUDA training')
        parser.add_argument('--dry-run', action='store_true', default=False,
                            help='quickly check a single pass')
        parser.add_argument('--seed', type=int, default=1, metavar='S',
                            help='random seed (default: 1)')
        parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                            help='how many batches to wait before logging training status')
        # Saving is on by default. A store_true flag with default=True can never be
        # switched off, so --no-save-model is added; --save-model still parses.
        parser.add_argument('--save-model', action='store_true', default=True,
                            help='For Saving the current Model')
        parser.add_argument('--no-save-model', dest='save_model', action='store_false',
                            help='do not save (or reload) the trained model')
        return parser.parse_args()


    def _make_loaders(args, use_cuda):
        """Return (train_loader, test_loader) for MNIST normalized with its mean/std."""
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        # download=True fetches the data on a fresh machine; existing files are reused.
        train_set = datasets.MNIST('', train=True, download=True, transform=transform)
        test_set = datasets.MNIST('', train=False, download=True, transform=transform)

        # Always shuffle training data, on CPU as well as CUDA. Keep the test set in
        # a fixed order and honour --test-batch-size.
        train_kwargs = {'batch_size': args.batch_size, 'shuffle': True}
        test_kwargs = {'batch_size': args.test_batch_size, 'shuffle': False}
        if use_cuda:
            cuda_kwargs = {'num_workers': 1, 'pin_memory': True}
            train_kwargs.update(cuda_kwargs)
            test_kwargs.update(cuda_kwargs)
        return (torch.utils.data.DataLoader(train_set, **train_kwargs),
                torch.utils.data.DataLoader(test_set, **test_kwargs))


    def _show_predictions(model, device, test_loader, num_batches=5):
        """Plot the model's predicted labels for the first ``num_batches`` test batches."""
        model.eval()
        with torch.no_grad():
            # Iterate the loader once, so each plot shows a different batch.
            for batch_idx, (images, _) in enumerate(test_loader):
                if batch_idx >= num_batches:
                    break
                pred = model(images.to(device)).argmax(dim=1)
                plot_image(images, pred.cpu(), 'test')


    def main():
        """Train the CNN on MNIST, evaluating after every epoch, then plot sample predictions."""
        args = _parse_args()
        use_cuda = not args.no_cuda and torch.cuda.is_available()

        torch.manual_seed(args.seed)
        device = torch.device("cuda" if use_cuda else "cpu")

        train_loader, test_loader = _make_loaders(args, use_cuda)

        model = Net().to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
        # optimizer = optim.Adam(model.parameters())

        # Multiply the learning rate by args.gamma after every epoch.
        scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
        for epoch in range(1, args.epochs + 1):
            train(args, model, device, train_loader, optimizer, epoch)
            test(model, device, test_loader)
            scheduler.step()

        if args.save_model:
            torch.save(model.state_dict(), "mnist_cnn.pt")
            # Reload the checkpoint just written to demonstrate restoring a model; only
            # done when it was actually saved, and mapped onto the current device.
            model.load_state_dict(torch.load("mnist_cnn.pt", map_location=device))

        # Inspect some test results.
        _show_predictions(model, device, test_loader)
    
    
    
    
    if __name__ == "__main__":
        main()


    # Recorded test results for the CNN:
    #
    # Adadelta with StepLR learning-rate decay:
    #   Test set: Average loss: 1.4739, Accuracy: 9873/10000 (99%)
    #
    # SGD, learning rate 0.1, no learning-rate decay:
    #   Test set: Average loss: 1.4735, Accuracy: 9880/10000 (99%)
    #
    # Adam with its default lr of 0.001 (with lr=0.1 the loss did not decrease), no decay:
    #   Test set: Average loss: 1.4749, Accuracy: 9862/10000 (99%)
    #
    # NOTE(review): these losses were presumably produced while Net.forward ended in
    # softmax. Feeding probabilities to F.cross_entropy bounds the loss below by
    # ln(1 + 9/e) ~= 1.461, which would explain the ~1.47 values. Re-measure with logits.
    
    
    
  • 相关阅读:
    mysql BETWEEN操作符 语法
    mysql IN操作符 语法
    mysql LIKE通配符 语法
    mysql TOP语句 语法
    mysql DELETE语句 语法
    mysql Update语句 语法
    mysql INSERT语句 语法
    mysql ORDER BY语句 语法
    mysql OR运算符 语法
    mysql AND运算符 语法
  • 原文地址:https://www.cnblogs.com/Jason66661010/p/13671528.html
Copyright © 2011-2022 走看看