zoukankan      html  css  js  c++  java
  • KNN 实现mnist数据集分类

    一 数据预处理

    训练数据集和验证数据集分别为train.csv和test.csv。数据集下载地址:http://pan.baidu.com/s/1eQyIvZG

    要分别对训练数据集和验证数据集进行分析,分析其内部数据的特征,下面分别对两个数据集进行处理:

    1.1 训练数据集处理

    train.csv 里面结构为42001 * 785。其中第一行为文字说明,应该去掉,其余每一行均表示一个图像,大小为28*28,共784个像素值;第一列为类标签,每一个标签表示一个图像所代表的数字,范围为0-9;所以处理的步骤为:把所有数据存入列表中;删除第一行,得到42000*785;分离开第一列和剩余数据,分别得到42000*1和42000*784两个矩阵。

    具体代码如下:

    def loadtraindata(trainfile):#传参为所读文件名
        l = list()#创建序列,要保存文件内容
        with open(trainfile,'rb') as filename:        
            lines = csv.reader(filename)
            for line in lines:
                l.append(line)
            del l[0]#删除第一行
            l = np.array(l)#转换为数组
            label = l[:,0]#取数组内所有行第一列元素
            data = l[:,1:]#取数组内所有行,从第二列至最后列元素
            label = np.int32(label)#int32为numpy 内部函数,进行数据类型转换
            data = nomalizing(np.int32(data))#nomalizing 为自定义函数,进行数据标准化
            return data,label
            

    标准化函数代码如下:

    def nomalizing(array):
        m,n = np.shape(array)#shape函数为得到数组的各个维度
        for i in xrange(m):
            for j in xrange(n):
                if array[i,j] != 0:
                    array[i,j] = 1
        return array

    二 KNN实现分类

    现已知训练集中有42000组元素和对应每组的类别,现给出一个未知类别的一组元素,要求预测其类别。KNN的做法是:找到与该组元素最近的k组;找到这k组元素里类别相同数最多的一个类别;认为该类别就是该未知类别元素的类别;

    2.1 KNN具体代码如下:

    第一种代码:

    找到与该组元素最近的k组:这里面涉及到几个点,1、如何判断最近?有欧几里得距离,曼哈顿距离。2、如何找到k组?即既要找到k个距离最小的组,同时要知道这些组的索引;因为这些组的索引用于知道他们的类别。所以我采用字典这个数据结构,既储存距离,同时存储索引。

    def KNN(X,traindata,trainlabel,k):#X为未知类别数据,k为最近邻个数
        find = dict()
        listkey = 0
        aa = [0,1,2,3,4,5,6,7,8,9]
        bb = [0,0,0,0,0,0,0,0,0,0]    
        m,n = np.shape(traindata)           
        for i in xrange(30000):#训练数据集中前三万组用于训练
            sum = 0
            for j in xrange(n):
                sum += math.fabs(X[j]-traindata[i,j])
            if i < k:
                find[i]=sum
            else:
                for key in find.keys():
                    if sum < find[key]:
                        del find[key]
                        find[i] = sum
                        break
        for key in find.keys():#找到这k个点类别相同数最多的类别
            for i in xrange(10):
                if trainlabel[key] == aa[i]:
                    bb[i] += 1
                                
        for i in xrange(10):
            if bb[i] == max(bb):
                listkey = i
                break
                
        return aa[listkey]#返回该未知类型数据的预测类别

    第二种代码:

    由于训练集数据量较大,且都是数组之间的操作,可以使用numpy库中array和mat函数进行处理,提高计算速度。

    def KNN(X,traindata,trainlabel,k):
        X = np.mat(X)#这三行实现的是将序列转换成矩阵,mat是numpy里转换成矩阵的函数
        traindata = np.mat(traindata)
        trainlabel = np.mat(trainlabel)
        trainsize = traindata.shape[0]#得到traindata的第一维大小   
        distance = np.sum(np.array(np.tile(X,(trainsize,1))-traindata)**2,1)
        distancesort = distance.argsort()
        
        countdict = dict()
        for i in xrange(k):
            Xlabel= trainlabel[0,distancesort[i]]#distancesort存储的是训练数据集的索引,前k个为距离最小的k个点的索引,通过trainlabel得到k个点的类别
            countdict[Xlabel] = countdict.get(Xlabel,0) + 1#通过字典,存储k个点上每个类别和对应的类别数量
        countlist = sorted(countdict.iteritems(),key=lambda x:x[1],reverse = True)#对字典的值,按降序排列,得到降序排列的存储各元祖的序列
        return countlist[0][0]#其序列的第一个元组为类别数最多的元组,第一个元素为其类别。将该类别赋值给未知元组

    2.2 计算召回率

    def compute(traindata,trainlabel,k):
        error = 0
        for i in xrange(41800,42000):#用200组进行验证
            X = traindata[i]    
            if KNN(X,traindata,trainlabel,k) != trainlabel[i]:
                error += 1
        
        return 1 - error / 200.00

    三 补充:(涉及到的Python和Numpy语法细节)

    3.1 Numpy

    Numpy(一个用Python实现的科学计算包),包括:1、一个强大的N维数组对象Array;2、用于整合C/C++和Fortran代码的工具包;3、实用的线性代数、傅里叶变换和随机数生成函数。

    3.1.1 生成数组

    创建数组采用array函数,它接受一切序列型的对象,产生一个新的含有传入数据的Numpy数组。

    import numpy as np
    
    a = [[2,3,4,5,6,7],[1,0,2,6,5,3]]
    aa = np.array(a)
    
    b = [1,3,5]
    bb = np.array(b)
    
    np.zeros(3)
    np.zeros((4,5))
    
    np.ones(3)
    np.ones((4,5))

    3.1.2 索引与切片

    import numpy as np
    
    a = [[2,3,4,5,6,7],[1,0,2,6,5,3],[2,4,5,6,4,3]]
    aa = np.array(a)
    
    aa[1]#索引
    aa[1,2]
    
    aa[:]#切片,没有逗号默认只行切片
    aa[:,:]
    aa[1:,:1]#行是从第二行到最后,列是从开始到第二行(不包括)
    
    aa[1:] = 2 #切片本质上不是复制,所以对它的修改会影响原数组
    bb = aa[1:].copy() #复制切片,再修改b ,不会影响原数组

    3.1.3 数组/矩阵转置

    import numpy as np
    
    a = [[2,3,4,5,6,7],[1,0,2,6,5,3],[2,4,5,6,4,3]]
    aa = np.mat(a)
    
    aa.T#矩阵转置,只有求T时aa可以是数组,其他都必须是矩阵
    aa.H#矩阵共轭转置
    aa.I#矩阵的逆矩阵
    aa.A#矩阵的二维视图

    3.1.4 数组与矩阵

     Matrix 类型继承于ndarray类型,因此含有ndarray的所有属性和方法。Matrix类型和ndarray类型常用的不同有:

    a . Matrix对象是二维的。例子中mat之后bb为二维的矩阵

    import numpy as np
    
    b = [3,5,74,6]
    bb = np.mat(b)
    print bb[0,3]

    b . Matrix类型的乘法覆盖了array的乘法,使用的是矩阵的乘法运算。

    import numpy as np
    
    b = [3,5,74,6]
    bb = np.array(b)
    cc = np.mat(b)
    
    bb*bb#数组乘法,为元素间相乘,即点乘
    cc*cc.T#矩阵乘法,遵守前一个矩阵的列等于后一个矩阵的行这样的矩阵运算规则

    c . Matrix 类型的幂运算覆盖了array的幂运算。

    import numpy as np
    
    b = [[3,5],[74,6]]
    bb = np.array(b)
    cc = np.mat(b)
    
    print bb**2#数组的幂运算,是对每一个元素进行幂运算,bb不必须是行列相同
    print cc**2#矩阵的幂运算,要求矩阵cc为方阵,然后进行方阵之间的矩阵运算

    d . 矩阵具有转置、共轭转置、逆矩阵等特有属性。

    3.1.5 数组排序

    a .  sorted 方法

    python 的内置函数(built-in functions)

     sorted(...)
        sorted(iterable, cmp=None, key=None, reverse=False) --> new sorted list


    iterable:是可迭代类型;
    cmp:用于比较的函数,比较什么由key决定,有默认值,迭代集合中的一项;
    key:用列表元素的某个属性和函数进行作为关键字,有默认值,迭代集合中的一项;
    reverse:排序规则. reverse = True 或者 reverse = False,有默认值。
    返回值:是一个经过排序的可迭代类型,与iterable一样。

    import numpy as np
    
    b = [[0,3],[2,2],[4,2]]
    
    bb = np.array(b)
    
    print sorted(bb,key=lambda x:x[1],reverse=False)#key指把bb代入x,对x第二维进行比较,reverse=True为降序,sorted排序不影响b
    print sorted(bb,key = lambda x:(x[1],x[0]),reverse = True)#这里的key是先按第二维排序,再按第一维排序

    b. sort方法

    list的内置函数

    sort(...)
     |      L.sort(cmp=None, key=None, reverse=False) -- stable sort *IN PLACE*;
     |      cmp(x, y) -> -1, 0, 1

    b = [[0,3],[1,4],[4,0]]
    
    b.sort(key=lambda x:x[1],reverse=False)
    print b #对b调用sort函数会导致b的变化

    c . argsort 方法

     argsort(a, axis=-1, kind='quicksort', order=None)
        Returns the indices that would sort an array.

    a : 要排序的数组
    axis : int or None, optional
            Axis along which to sort.  The default is -1 (the last axis). If None,
            the flattened array is used.

       axis = 0 按列排序(每列之间排序),axis= 1 按行排序(每行之间排序),默认行排序

    kind : {'quicksort', 'mergesort', 'heapsort'}, optional
            Sorting algorithm.

    import numpy as np
    
    a = [1,3,5]
    aa = np.array(a)
    
    b = [[0,3],[2,2],[4,2]]
    bb = np.array(b)
    
    print np.argsort(aa)
    print np.argsort(bb,0)
    print np.argsort(bb,1)
  • 相关阅读:
    PHP如何判断一个gif图片是否为动画?
    Linux常用系统管理命令(top、free、kill、df)
    Mysql字符串连接函数 CONCAT()与 CONCAT_WS()
    OSChina.net 的 Tomcat 配置 server.xml 参考
    修改Linux默认启动级别或模式
    更改CentOS 6.3 yum源为国内 阿里云源
    PHP session过期机制和配置
    PHP垃圾回收机制防止内存溢出
    memcache与memcached的区别
    【总结】虚拟机VirtualBox各种使用技巧
  • 原文地址:https://www.cnblogs.com/tosouth/p/4737149.html
Copyright © 2011-2022 走看看