zoukankan      html  css  js  c++  java
  • PytorchMNIST(使用Pytorch进行MNIST字符集识别任务)

      都说MNIST相当于机器学习界的Hello World。最近加入实验室,导师给我们安排了一个任务,但是我才刚刚入门呐!!没办法,只能从最基本的学起。

      Pytorch是一套开源的深度学习张量库。或者我倾向于把它当成一个独立的深度学习框架。为了写这么一个"Hello World"。查阅了不少资料,也踩了不少坑。不过同时也学习了不少东西,下面我把我的代码记录下来,希望能够从中受益更多,同时帮助其他对Pytorch感兴趣的人。代码的注释中有不对的地方欢迎批评指正。

      代码进行了注释,应该很方便阅读。 dependences: numpy torch torchvision python3 使用pip安装即可。

     1 # encoding: utf-8
     2 import torch
     3 import torch.nn as nn
     4 import torch.nn.functional as F #加载nn中的功能函数
     5 import torch.optim as optim #加载优化器有关包
     6 import torch.utils.data as Data
     7 from torchvision import datasets,transforms #加载计算机视觉有关包
     8 from torch.autograd import Variable
     9 
    10 BATCH_SIZE = 64
    11 
    12 #加载torchvision包内内置的MNIST数据集 这里涉及到transform:将图片转化成torchtensor
    13 train_dataset = datasets.MNIST(root='~/data/',train=True,transform=transforms.ToTensor(),download=True)
    14 test_dataset = datasets.MNIST(root='~/data/',train=False,transform=transforms.ToTensor())
    15 
    16 #加载小批次数据,即将MNIST数据集中的data分成每组batch_size的小块,shuffle指定是否随机读取
    17 train_loader = Data.DataLoader(dataset=train_dataset,batch_size=BATCH_SIZE,shuffle=True)
    18 test_loader = Data.DataLoader(dataset=test_dataset,batch_size=BATCH_SIZE,shuffle=False)
    19 
    20 #定义网络模型亦即Net 这里定义一个简单的全连接层784->10
    21 class Model(nn.Module):
    22     def __init__(self):
    23         super(Model,self).__init__()
    24         self.linear1 = nn.Linear(784,10)
    25 
    26     def forward(self,X):
    27         return F.relu(self.linear1(X))
    28 
    29 model = Model() #实例化全连接层
    30 loss = nn.CrossEntropyLoss() #损失函数选择,交叉熵函数
    31 optimizer = optim.SGD(model.parameters(),lr = 0.1)
    32 num_epochs = 5
    33 
    34 #以下四个列表是为了可视化(暂未实现)
    35 losses = [] 
    36 acces = []
    37 eval_losses = []
    38 eval_acces = []
    39 
    40 for echo in range(num_epochs):
    41     train_loss = 0   #定义训练损失
    42     train_acc = 0    #定义训练准确度
    43     model.train()    #将网络转化为训练模式
    44     for i,(X,label) in enumerate(train_loader):     #使用枚举函数遍历train_loader
    45         X = X.view(-1,784)       #X:[64,1,28,28] -> [64,784]将X向量展平
    46         X = Variable(X)          #包装tensor用于自动求梯度
    47         label = Variable(label)
    48         out = model(X)           #正向传播
    49         lossvalue = loss(out,label)         #求损失值
    50         optimizer.zero_grad()       #优化器梯度归零
    51         lossvalue.backward()    #反向转播,刷新梯度值
    52         optimizer.step()        #优化器运行一步,注意optimizer搜集的是model的参数
    53         
    54         #计算损失
    55         train_loss += float(lossvalue)      
    56         #计算精确度
    57         _,pred = out.max(1)
    58         num_correct = (pred == label).sum()
    59         acc = int(num_correct) / X.shape[0]
    60         train_acc += acc
    61 
    62     losses.append(train_loss / len(train_loader))
    63     acces.append(train_acc / len(train_loader))
    64     print("echo:"+' ' +str(echo))
    65     print("lose:" + ' ' + str(train_loss / len(train_loader)))
    66     print("accuracy:" + ' '+str(train_acc / len(train_loader)))
    67     eval_loss = 0
    68     eval_acc = 0
    69     model.eval() #模型转化为评估模式
    70     for X,label in test_loader:
    71         X = X.view(-1,784)
    72         X = Variable(X)
    73         label = Variable(label)
    74         testout = model(X)
    75         testloss = loss(testout,label)
    76         eval_loss += float(testloss)
    77 
    78         _,pred = testout.max(1)
    79         num_correct = (pred == label).sum()
    80         acc = int(num_correct) / X.shape[0]
    81         eval_acc += acc
    82 
    83     eval_losses.append(eval_loss / len(test_loader))
    84     eval_acces.append(eval_acc / len(test_loader))
    85     print("testlose: " + str(eval_loss/len(test_loader)))
    86     print("testaccuracy:" + str(eval_acc/len(test_loader)) + '
    ')

    运行后的结果如下:

        我们在上面的代码中,将图片对应的Pytorchtensor展平,并通过一个全连接层,仅仅是这样就达到了90%以上的准确率。如果使用卷积层,正确率有望达到更高。

      代码并不完备,还可以增加visualize和predict功能,等我学到更多知识后,有待后续添加。  

  • 相关阅读:
    Python的网络编程[0] -> socket[1] -> socket 模块
    Python的网络编程[0] -> socket[0] -> socket 与 TCP / UDP
    Python的功能模块[4] -> pdb/ipdb -> 实现 Python 的单步调试
    Python的功能模块[3] -> binascii -> 编码转换
    Python的功能模块[2] -> abc -> 利用 abc 建立抽象基类
    Python的功能模块[1] -> struct -> struct 在网络编程中的使用
    Python的功能模块[0] -> wmi -> 获取 Windows 内部信息
    Python的程序结构[8] -> 装饰器/Decorator -> 装饰器浅析
    Python的程序结构[7] -> 生成器/Generator -> 生成器浅析
    Python的程序结构[6] -> 迭代器/Iterator -> 迭代器浅析
  • 原文地址:https://www.cnblogs.com/chester-cs/p/11544898.html
Copyright © 2011-2022 走看看