zoukankan      html  css  js  c++  java
  • MNIST手写识别Python(详细注释)

    import argparse   #参数解析包
    import torch      #张量、加减乘除等运算,相当于numpy
    import torch.nn as nn  #包括各种函数
    import torch.nn.functional as F #包括激活函数,损失函数等
    import torch.optim as optim     #优化器
    from torchvision import datasets, transforms  #数据集加载器和图像处理(预处理)
    #torchvision.datasets是继承torch.utils.data.Dataset的子类
    
    class Net(nn.Module):  #继承父类,定义卷积层和全连接层的输入输出深度和卷积核大小
        def __init__(self): #一般将网络中具有学习参数的层放入其中
            super(Net, self).__init__()  #调用父类方法__init__()
            self.conv1 = nn.Conv2d(1, 20, 5, 1) #对输入信号进行二维卷积,参数:1:灰度图片通道 20:输出通道 5: 5x5的卷积核 1:步长
            self.conv2 = nn.Conv2d(20, 50, 5, 1) #上一层的输出通道数为该层的输入通道数
            self.fc1 = nn.Linear(4 * 4 * 50, 500)
            self.fc2 = nn.Linear(500, 10)  #最后通道数为10的目的是得到输出结果为0-9的各自概率大小
    
        def forward(self, x):   #定义前向传播,利用Autograd反向传播会自动实现
            """self.conv1(x) 输入:batch_size*1*28*28  1:第一个卷积层的输入通道, 28*28:图像像素大小 (x = batch_size)
                             输出:batch_size*20*(28-5+1)*(28-5+1) 20: 输出通道 5:卷积核大小 """
            x = F.relu(self.conv1(x))  #输入经过卷积层后通过激活函数作用实现非线性争强模型表达能力,relu()不改变输出shape:batch_size*20*24*24
            x = F.max_pool2d(x, 2, 2)  #2x2的最大池化层:特征降维,防止过拟合,获得更重要的相对位置。输出:batch_size*20*12*12
            x = F.relu(self.conv2(x))  #conv2(x)输入:batch_size*20*12*12 输出:batch_size*50*(12-5+1)*(12-5+1)
            x = F.max_pool2d(x, 2, 2)  #输入:batch_size*50*8*8 输出:batch_size*50*4*4
            x = x.view(-1, 4 * 4 * 50) #view相当于reshap,当不确定行数时设置为view(-1,x),x为确定的几列,所以该方法将数据reshape为一维
            x = F.relu(self.fc1(x))    #self.fc1(x)输入:batch_size*4*4*50 输出:batch_size*500
            x = self.fc2(x)            #self.fc2(x)输入:batch_size*500    输出:batch_size*10  
            #返回结果为10个类别各自的概率,dim=1指1维,相比softmax()归一化指数函数,log(softmax())计算速度更快,提高数值的稳定性
            return F.log_softmax(x, dim=1) 
    
    
    def train(args, model, device, train_loader, optimizer, epoch):
        model.train()   #训练模式,启用Batch Normalization(批量标准化)和Dlopout,防止过拟合
        for batch_idx, (data, target) in enumerate(train_loader):   #迭代训练数据的标签和数值
            data, target = data.to(device), target.to(device)       #选择训练设备GPU或CPU
            optimizer.zero_grad()  #梯度归零:因为在backward中梯度是累加的,所以每个batch_size进来都需要将上次的梯度归零
            output = model(data)   #将数据送入模型训练得到输出
            loss = F.nll_loss(output, target)  #通过输出和目标计算损失
            loss.backward()        #反向传播计算每个参数的梯度值,通过Autograd包实现
            optimizer.step()       #通过梯度下降更新参数
            if batch_idx % args.log_interval == 0:  #batch_idx每到args.log_interval(默认10)的整数倍时就打印训练轮次,损失等相关信息
                print('Train Epoch: {} [{}/{} ({:.0f}%)]	Loss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader), loss.item()))
                #loss.item():将loss里面的零维张量转换为浮点数
    
    
    def test(args, model, device, test_loader):
        model.eval()   #测试模式,关闭Batch Normalization(批量标准化)和Dlopout
        test_loss = 0  #初始化测试损失归零
        correct = 0    #初始化测试正确的数据个数为零
        with torch.no_grad():   #强制不进行计算图构建(计算过程的构建,以便梯度反向传播),节省内存
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += F.nll_loss(output, target, reduction='sum').item()  # 计算损失并对损失求和,返回浮点型
                pred = output.argmax(dim=1, keepdim=True)  
                """torch.argmax()中dim的值表示忽略某个维度,输出其他维度的最大值索引,
                在二阶张量中,dim=1表示忽略列,比较每行得到每行的最大值的索引。
                keepdim=True表示保持输入输出维度相同"""
                
                correct += pred.eq(target.view_as(pred)).sum().item()
                #target.view_as(pred):将一维的target转换为和pred相同的二维张量
                #计算最大概率输出索引与目标类别是否相同,统计准确类别的总数目
    
        test_loss /= len(test_loader.dataset)  #计算测试集平均损失
    
        #打印测试集平均损失,准确率
        print('
    Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)
    '.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))
    
    
    def main():
        #argpares:参数解析模块,从sys.argv中解析参数
        parser = argparse.ArgumentParser(description='PyTorch MNIST Example')  #创建解析对象
        #添加命令行参数和对象
        parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                            help='input batch size for training (default: 64)')  # 添加批量大小参数,默认值64
        parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                            help='input batch size for testing (default: 1000)') #添加测试批量大小参数,默认值1000
        parser.add_argument('--epochs', type=int, default=10, metavar='N',
                            help='number of epochs to train (default: 10)')      #添加训练轮次参数,默认十次
        parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                            help='learning rate (default: 0.01)')                #添加学习率参数,默认值0.01
        parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                            help='SGD momentum (default: 0.5)')                  #添加动量参数(防止局部最优),默认0.5
        parser.add_argument('--no-cuda', action='store_true', default=False,
                            help='disables CUDA training')                       #参数表示CUDA不可用,默认False
        parser.add_argument('--seed', type=int, default=1, metavar='S',
                            help='random seed (default: 1)')                     #种子参数(确保每次产生的随机数相同),默认为1
        parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                            help='how many batches to wait before logging training status') #该参数确定多少个数据(默认10个)训练完成后打印对应参数信息
        parser.add_argument('--save-model', action='store_true', default=False,
                            help='For Saving the current Model')                 #保存模型参数
    
        args = parser.parse_args()  #参数实例化,从 sys.argv 中参数解析
        use_cuda = not args.no_cuda and torch.cuda.is_available() #判断是否有CUDA且是否可用
        torch.manual_seed(args.seed)  
        """为CPU设置随机种子(训练时参数初始化需要得到一组随机的数,
        但要保证这组随机数在每次训练时都保持不变,训练的结果才有意义),使得结果是确定的"""
        device = torch.device("cuda" if use_cuda else "cpu") #设备的选择CPU或GPU
        #如果use_cuda为Ture,则 kwargs = {'num_workers': 1, 'pin_memory': True},否则为空序列
        kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
        
        #通过数据接口将MNIST数据下载,并以batch size形式封装成torch
        #数据加载
        train_loader = torch.utils.data.DataLoader(
            #数据下载
            datasets.MNIST('../data', train=True, download=True,
                           #数据处理
                           transform=transforms.Compose([
                               #shape为(H,W,C通道)转换为(C,H,W)的tensor,并将数据除以255归一化[0,1]的数
                               transforms.ToTensor(),  
                               #正则化,通过平均值和标准差将数据标准化,并把取值范围规定到[-1,1]的数,降低模型复杂度,防止过拟合
                               transforms.Normalize((0.1307,), (0.3081,))  
                           ])),
            batch_size=args.batch_size, shuffle=True, **kwargs) #将训练数据打乱
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('../data', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
            batch_size=args.test_batch_size, shuffle=True, **kwargs) #将测试数据打乱
    
        model = Net().to(device)  #模型实例化
        #确定优化器参数,学习率(步长)和动量(防止局部最优)。model.parameters优化器初始化
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    
        for epoch in range(1, args.epochs + 1): #训练轮次从1开始到epoch
            train(args, model, device, train_loader, optimizer, epoch) #训练
            test(args, model, device, test_loader)  #测试
        #模型保存
        if (args.save_model):
            #model.state_dict():以字典形式保存参数。 mnist_cnn.pt:保存路径
            torch.save(model.state_dict(), "mnist_cnn.pt")  
    
    
    if __name__ == '__main__':    main()
  • 相关阅读:
    delphi point数据类型
    Sql Server 2008 R2链接服务器Oracle数据库
    ORA-28000 账号被锁定的解决办法
    [Oracle] sqlplus / as sysdba ora-01031 insufficient privileges
    Oracle的操作系统认证(/ as sydba 登录方式)
    Delphi使用线程TThread查询数据库
    oracle
    统计字符串中字符出现的次数-Python
    Jmeter保存下载的文件
    如何在Microsoft Store上免费获得 HEIF、HEVC 编码支持
  • 原文地址:https://www.cnblogs.com/Uriel-w/p/15273137.html
Copyright © 2011-2022 走看看