zoukankan      html  css  js  c++  java
  • notMNIST 数据集pyTorch分类

    notMNIST数据集分类

    简介

    notMNIST数据集 是于2011公布的,可以认为是MNIST数据集地一个加强版本。数据集包含了从A到J十个字母,由large与small两个子集组成。其中samll数据集是经过手工清理的,包含19k个图片,误分类率越为0.5%,large数据集是未经过手工清理的,包含500k张图片,误分类率约为6.5%。

    作者推荐在large数据集上训练网络,在small数据集上测试网络。可以将large数据集分为5/6和1/6,使用5/6做training,1/6做validation。

    在该网站上网友做的正确率较高的再97%到98%,我自己使用resnet最高达到了98.04%。接下来就说一下我做的步骤。

    分类

    数据预处理

    一步要解决的是数据集的加载。原始数据集是一些很小地图片,一个一个地从磁盘中加载无疑会拖慢模型训练的速度。最好的方式就是将所有数据都加载到内存中。因此,可以将数据加载到内存中,并将标准化之后的数据以二进制文件使用pickle保存到磁盘。这样,每次从磁盘中读取数据可以直接读取二进制文件,否则每次读取数据集中地图片都会耗时很久。

    import os, cv2, pickle
    import numpy as np
    rootdir = 'D:/DataSet/notMNIST/notMNIST_large'
    classlist = os.listdir(rootdir)
    imgLabels = []
    imgNames = []
    for classes in classlist:
        imgFolder = os.path.join(rootdir, classes)
        imgnames = os.listdir(imgFolder)
        imgLabels.extend([idxName[classes]] * len(imgnames))
        imgNames.extend([os.path.join(imgFolder, img) for img in imgnames])
     
    imgs = np.zeros((len(imgLabels), 28, 28), np.float)
    idx = 0
    print('loading training data......')
    for imgname in imgNames:
        try:
            img = cv2.imread(imgname, 0).astype(np.float) / 255.0
            imgs[idx, :, :] = img
            idx += 1
        except AttributeError:
            np.delete(imgs, idx, axis=0)
    print('loading training data finished, %d samples' % imgs.shape[0])
    
    train_mean, train_std = np.mean(imgs), np.std(imgs)
    print('%.6f, %6f', train_mean, train_std)
    imgs = (imgs - train_mean) / train_std
    data = {'images': imgs, 'labels': imgLabels}
    
    with open('D:/DataSet/notMNIST/trainset', 'wb') as f:
        pickle.dump(data, f)
    print('train set finished')
    
    
    rootdir = 'D:/DataSet/notMNIST/notMNIST_small'
    classlist = os.listdir(rootdir)
    imgLabels = []
    imgNames = []
    for classes in classlist:
        imgFolder = os.path.join(rootdir, classes)
        imgnames = os.listdir(imgFolder)
        imgLabels.extend([idxName[classes]] * len(imgnames))
        imgNames.extend([os.path.join(imgFolder, img) for img in imgnames])
    
    imgs = np.zeros((len(imgLabels), 28, 28), np.float)
    idx = 0
    print('loading test data......')
    for imgname in imgNames:
        try:
            img = cv2.imread(imgname, 0).astype(np.float) / 255.0
            imgs[idx, :, :] = img
            idx += 1
        except AttributeError:
            np.delete(imgs, idx, axis=0)
    print('loading test data finished. % d samples' % imgs.shape[0])
    
    train_mean, train_std = np.mean(imgs), np.std(imgs)
    imgs = (imgs - train_mean) / train_std
    data = {'images': imgs, 'labels': imgLabels}
    
    with open('D:/DataSet/notMNIST/testset', 'wb') as f:
        pickle.dump(data, f)
    print('test set finished')
    

    使用try语句地原因是,在读取过程中可能出现一些错误。

  • 相关阅读:
    C#水晶报表的分页统计字段
    ymPrompt消息提示组件js实现
    C#委托学习 原文推荐:http://www.cnblogs.com/warensoft/archive/2010/03/19/1689806.html?login=1#commentform
    C#之winfrom打印图片
    TreeView控件如何设置节点显示与隐藏,主要是用来做后台权限,没有权限的就隐藏,有权限的就显示?
    C#多线程间同步实例 原文:http://blog.csdn.net/zhoufoxcn/article/details/2453803
    C#反射的应用 原文摘自:http://blog.csdn.net/Tsapi/article/details/6234205
    C#编写的winform程序打包方法
    虚拟机下的CentOS环境中安装Node.js和npm
    RequireJS模块化与GruntJS构建
  • 原文地址:https://www.cnblogs.com/zi-wang/p/9891245.html
Copyright © 2011-2022 走看看