zoukankan      html  css  js  c++  java
  • python读取MNIST image数据

    LeCun MNIST 数据集下载

    import numpy as np
    import struct
    
    def loadImageSet(which=0, path=None):
        """Load an uncompressed MNIST IDX3 image file into a numpy array.

        Args:
            which: 0 selects "..//dataset//train-images-idx3-ubyte"; any other
                value selects "..//dataset//t10k-images-idx3-ubyte". Ignored when
                *path* is given.
            path: optional explicit path to an IDX3 image file.

        Returns:
            Integer array of shape (num_images, rows, cols).

        Raises:
            ValueError: if the magic number is not 2051 (unsigned-byte, 3-D IDX)
                or the file holds fewer pixels than its header declares.
        """
        print("load image set")
        if path is None:
            if which == 0:
                path = "..//dataset//train-images-idx3-ubyte"
            else:
                path = "..//dataset//t10k-images-idx3-ubyte"
        # 'with' closes the file even if read() fails.
        with open(path, 'rb') as binfile:
            buffers = binfile.read()

        head = struct.unpack_from('>IIII', buffers, 0)
        print("head, %s" % (head,))

        # IDX3 header: magic, item count, rows, cols (all big-endian uint32).
        magic, imgNum, rows, cols = head
        if magic != 2051:
            raise ValueError("%s is not an IDX3 image file (magic number %d)"
                             % (path, magic))
        offset = struct.calcsize('>IIII')

        # frombuffer reads the pixels without first materialising a tuple with
        # one Python int per pixel (what struct.unpack_from('>NB') produced).
        # astype(np.int_) keeps the integer dtype the tuple->array conversion
        # used to yield, so callers doing arithmetic see no uint8 wrap-around.
        pixels = np.frombuffer(buffers, dtype=np.uint8,
                               count=imgNum * rows * cols, offset=offset)
        imgs = pixels.astype(np.int_).reshape(imgNum, rows, cols)
        print("load imgs finished")
        return imgs
    
    def loadLabelSet(which=0, path=None):
        """Load an uncompressed MNIST IDX1 label file into a column array.

        Args:
            which: 0 selects "..//dataset//train-labels-idx1-ubyte"; any other
                value selects "..//dataset//t10k-labels-idx1-ubyte". Ignored when
                *path* is given.
            path: optional explicit path to an IDX1 label file.

        Returns:
            Integer array of shape (num_labels, 1).

        Raises:
            ValueError: if the magic number is not 2049 (unsigned-byte, 1-D IDX)
                or the file holds fewer labels than its header declares.
        """
        print("load label set")
        if path is None:
            if which == 0:
                path = "..//dataset//train-labels-idx1-ubyte"
            else:
                path = "..//dataset//t10k-labels-idx1-ubyte"
        # 'with' closes the file even if read() fails.
        with open(path, 'rb') as binfile:
            buffers = binfile.read()

        head = struct.unpack_from('>II', buffers, 0)
        print("head, %s" % (head,))

        # IDX1 header: magic, item count (big-endian uint32).
        magic, imgNum = head
        if magic != 2049:
            raise ValueError("%s is not an IDX1 label file (magic number %d)"
                             % (path, magic))
        offset = struct.calcsize('>II')

        # frombuffer avoids building a Python tuple of every label; astype keeps
        # the integer dtype the old tuple->array conversion produced.
        raw = np.frombuffer(buffers, dtype=np.uint8, count=imgNum, offset=offset)
        labels = raw.astype(np.int_).reshape(imgNum, 1)
        print('load label finished')
        return labels
    
    if __name__ == "__main__":
        # Smoke run: load the training images, then the training labels, from
        # the default ../dataset locations.
        images = loadImageSet()
        # To eyeball the first digit with the author's PlotUtil helper
        # (presumably a local plotting module; it is not shown here):
        # import PlotUtil as pu
        # pu.showImgMatrix(images[0])
        loadLabelSet()

    以及一个方便训练时按批次取数据的 reader（读取 mnist.pkl.gz）

    import numpy as np
    import struct
    import gzip
    import cPickle
    
    class MnistReader():
        """Mini-batch reader over the pickled MNIST splits in mnist.pkl.gz.

        Training batches are served sequentially; when an epoch is exhausted the
        training pairs are reshuffled and serving restarts from the beginning.
        Each validation batch is a fresh random sample of the validation split.
        """

        def __init__(self,mnist_path,data_dim=1,one_hot=True):
            '''
            mnist_path: the path of mnist.pkl.gz
            data_dim=1 images are returned as [N,784]
            data_dim=3 images are returned as [N,28,28,1]
            one_hot: if true, labels are one-hot vectors (like: [0,1,0,0,0,0,0,0,0,0])
            '''
            self.mnist_path = mnist_path
            self.data_dim = data_dim
            self.one_hot = one_hot
            self.load_minist(mnist_path)

            # list(...) so the pairs support len() and in-place shuffling on
            # Python 3, where zip() returns a one-shot iterator.
            self.train_datalabel = list(zip(self.train_x, self.train_y))
            self.valid_datalabel = list(zip(self.valid_x, self.valid_y))

            # Index (in batches, not samples) of the next training batch.
            self.batch_offset_train = 0

        def _make_batch(self, pairs):
            """Split (image, label) pairs into parallel image and label lists.

            Images are reshaped to [28,28,1] when data_dim == 3; labels become
            length-10 one-hot float vectors when one_hot is set.
            """
            imgs = []
            labels = []
            for d, l in pairs:
                if self.data_dim == 3:
                    d = np.reshape(d, [28, 28, 1])
                if self.one_hot:
                    onehot = np.zeros(10)
                    onehot[l] = 1
                    l = onehot
                imgs.append(d)
                labels.append(l)
            return imgs, labels

        def next_batch_train(self,batch_size):
            '''
            Return the next training mini-batch of exactly batch_size samples.

            Returns a list of images with shape [N,784] or [N,28,28,1] (depending
            on self.data_dim) and a list of labels with shape [N] or [N,10]
            (depending on self.one_hot). Only full batches are served: once fewer
            than batch_size unserved pairs remain, the training pairs are
            reshuffled and a new epoch starts.

            Raises ValueError if batch_size < 1 or exceeds the training set size
            (previously this recursed forever or divided by zero).
            '''
            total = len(self.train_datalabel)
            if batch_size < 1 or batch_size > total:
                raise ValueError("batch_size must be between 1 and %d, got %r"
                                 % (total, batch_size))
            if self.batch_offset_train >= total // batch_size:
                # Epoch exhausted: restart from a fresh permutation.
                self.batch_offset_train = 0
                np.random.shuffle(self.train_datalabel)
            # batch_offset_train counts batches, so scale it to a sample index;
            # using it directly made consecutive batches overlap.
            start = self.batch_offset_train * batch_size
            self.batch_offset_train += 1
            return self._make_batch(self.train_datalabel[start:start + batch_size])

        def next_batch_val(self,batch_size):
            '''
            Return a random validation batch.

            The validation pairs are shuffled in place and the first batch_size
            of them are returned (fewer if the split is smaller). Images have
            shape [N,784] or [N,28,28,1] and labels [N] or [N,10], as in
            next_batch_train.
            '''
            np.random.shuffle(self.valid_datalabel)
            return self._make_batch(self.valid_datalabel[:batch_size])

        def load_minist(self,dataset):
            '''
            Load the (train, valid, test) splits pickled in the gzip file at
            *dataset* into self.{train,valid,test}_{x,y} and print their shapes.

            Security: unpickling can execute arbitrary code; only load a
            mnist.pkl.gz obtained from a trusted source.
            '''
            try:
                import cPickle as pickle  # Python 2
                load_kwargs = {}
            except ImportError:
                import pickle  # Python 3
                # The file was written for cPickle (Python 2). encoding='latin1'
                # lets Python 3 decode numpy arrays pickled by Python 2.
                load_kwargs = {"encoding": "latin1"}
            print("load dataset")
            with gzip.open(dataset, 'rb') as f:
                train_set, valid_set, test_set = pickle.load(f, **load_kwargs)
            self.train_x, self.train_y = train_set
            self.valid_x, self.valid_y = valid_set
            self.test_x, self.test_y = test_set
            print("train image,label shape: %s %s" % (self.train_x.shape, self.train_y.shape))
            print("valid image,label shape: %s %s" % (self.valid_x.shape, self.valid_y.shape))
            print("test  image,label shape: %s %s" % (self.test_x.shape, self.test_y.shape))
            print("load dataset end")
    
    if __name__ == "__main__":
        # Smoke run: pull a single 28x28x1 training sample and show it.
        reader = MnistReader('../dataset/mnist.pkl.gz', data_dim=3)
        batch_images, batch_labels = reader.next_batch_train(batch_size=1)
        print(batch_images)
        print(batch_labels)
    

    第三种加载方式直接读取 .gz 压缩的 IDX 文件，需要 gzip、struct 和 numpy

    import gzip, struct
    
    def _read(image, label, mnist_dir='your_dir/'):
        """Read one gzipped MNIST image/label IDX file pair.

        Args:
            image: file name of the gzipped IDX3 image file inside mnist_dir.
            label: file name of the gzipped IDX1 label file inside mnist_dir.
            mnist_dir: directory holding both files (defaults to the original
                'your_dir/' placeholder).

        Returns:
            (images, labels): writable uint8 array of shape (num, rows, cols)
            and writable int8 array of shape (num,).

        Raises:
            ValueError: on a wrong magic number, a truncated file, or when the
                two files disagree on the number of items.
        """
        import os
        # This snippet's own import line brings in only gzip and struct.
        import numpy as np

        with gzip.open(os.path.join(mnist_dir, label), 'rb') as flbl:
            magic, num = struct.unpack(">II", flbl.read(8))
            if magic != 2049:
                raise ValueError("%s: bad label-file magic number %d" % (label, magic))
            # np.fromstring (binary mode) is deprecated; frombuffer views the
            # immutable bytes read-only, so copy() keeps the result writable as before.
            labels = np.frombuffer(flbl.read(), dtype=np.int8, count=num).copy()
        with gzip.open(os.path.join(mnist_dir, image), 'rb') as fimg:
            magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
            if magic != 2051:
                raise ValueError("%s: bad image-file magic number %d" % (image, magic))
            if num != len(labels):
                raise ValueError("%s has %d images but %s has %d labels"
                                 % (image, num, label, len(labels)))
            pixels = np.frombuffer(fimg.read(), dtype=np.uint8, count=num * rows * cols)
            images = pixels.copy().reshape(num, rows, cols)
        return images, labels
    
    def get_data():
        """Load the gzipped MNIST training and test splits via _read.

        Returns:
            A list [train_img, train_label, test_img, test_label]; the training
            pair is read first.
        """
        file_pairs = (
            ('train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz'),
            ('t10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz'),
        )
        splits = []
        for image_file, label_file in file_pairs:
            # each _read call yields (images, labels); flatten into one list
            splits.extend(_read(image_file, label_file))
        return splits
  • 相关阅读:
    RAID卡 BBU Learn Cycle周期的影响
    Linux下查看Raid磁盘阵列信息的方法
    ROS导航包的介绍
    ROS源码解读(二)--全局路径规划
    ROS源码解读(一)--局部路径规划
    VS运行release版本正常,直接执行exe文件会出现问题
    IFM设备 Linux方面资料
    Map-making Robots: A Review of the Occupancy Grid Map Algorithm
    Eigen 介绍及简单使用
    绘制二维障碍栅格地图
  • 原文地址:https://www.cnblogs.com/judejie/p/9143974.html
Copyright © 2011-2022 走看看