zoukankan      html  css  js  c++  java
  • python用K近邻(KNN)算法分类MNIST数据集和Fashion MNIST数据集

    一、KNN算法的介绍

      K最近邻(k-Nearest Neighbor,KNN)分类算法是最简单的机器学习算法之一,理论上比较成熟。KNN算法首先将待分类样本表达成和训练样本一致的特征向量;然后根据距离计算待测试样本和每个训练样本的距离,选择距离最小的K个样本作为近邻样本;最后根据K个近邻样本判断待分类样本的类别。KNN算法的正确选取是分类正确的关键因素之一,而近邻样本是通过计算测试样本与每个训练集样本的距离来选定的,故定义合适的距离是KNN正确分类的前提。

    本文中在上述研究的基础上,将特征属性值对类别判断的重要性视为同样重要,将样本距离重新定义为任意两样本间像素点间的相关距离,并且距离计算使用的是距离。

    二、算法原理

      k-近邻算法(KNN),其工作原理是存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前k个最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数。最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类。

      收集和准备数据,这里使用的是mnist数据集和fashion mnist数据集,输入样本数据和结构化的输出结果,可以调整k的值,然后运行k-近邻算法判断输入数据分别属于哪个分类,最后计算错误率和准确率。

    KNN算法(k邻近算法分类算法),就是k个最近的邻居的,说的是每个样本都可以用它最接近的k个邻居来代表,核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。KNN算法不仅可以用于分类,还可以用于回归。通过找出一个样本的k个最近邻居,将这些邻居的属性的平均值赋给该样本,就可以得到该样本的属性。在KNN中,通过计算对象间距离来作为各个对象之间的非相似性指标,避免了对象之间的匹配问题,在这里距离使用的是欧氏距离。

    详细实现:将mnist数据集和fashion mnist数据集包括训练集和验证集导入到工程文件中,接着计算验证集和训练集的距离,并从小到达排序得到距离最近的k个邻居,并通过投票得到所属类别最高的类别,并判断该验证集的图片属于该类别,接着讲该类别的标签和验证集的标签进行比对,如果相符合则是正确的,如果不相符合,则是属于出错,最后输出计算出的错误率和准确率。

    三、数据集介绍
      MNIST数据集,训练集60000张图片和标签;测试集有10000张图片和标签。读取28*28图片以后,要将每张图片转换为1*784的向量。
    四、KNN算法实现和结果分析
    代码实现:
    from numpy import *
    import operator
    import os
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib import cm
    from os import listdir
    from mpl_toolkits.mplot3d import Axes3D
    import struct

    #读取图片
    def read_image(file_name):
    #先用二进制方式把文件都读进来
    file_handle=open(file_name,"rb") #以二进制打开文档
    file_content=file_handle.read() #读取到缓冲区中

    offset=0
    head = struct.unpack_from('>IIII', file_content, offset) # 取前4个整数,返回一个元组
    offset += struct.calcsize('>IIII')
    imgNum = head[1] #图片数
    rows = head[2] #宽度
    cols = head[3] #高度
    # print(imgNum)
    # print(rows)
    # print(cols)

    #测试读取一个图片是否读取成功
    #im = struct.unpack_from('>784B', file_content, offset)
    #offset += struct.calcsize('>784B')

    images=np.empty((imgNum , 784))#empty,是它所常见的数组内的所有元素均为空,没有实际意义,它是创建数组最快的方法
    image_size=rows*cols#单个图片的大小
    fmt='>' + str(image_size) + 'B'#单个图片的format

    for i in range(imgNum):
    images[i] = np.array(struct.unpack_from(fmt, file_content, offset))
    # images[i] = np.array(struct.unpack_from(fmt, file_content, offset)).reshape((rows, cols))
    offset += struct.calcsize(fmt)
    return images

    '''bits = imgNum * rows * cols # data一共有60000*28*28个像素值
    bitsString = '>' + str(bits) + 'B' # fmt格式:'>47040000B'
    imgs = struct.unpack_from(bitsString, file_content, offset) # 取data数据,返回一个元组
    imgs_array=np.array(imgs).reshape((imgNum,rows*cols)) #最后将读取的数据reshape成 【图片数,图片像素】二维数组
    return imgs_array'''

    #读取标签
    def read_label(file_name):
    file_handle = open(file_name, "rb") # 以二进制打开文档
    file_content = file_handle.read() # 读取到缓冲区中

    head = struct.unpack_from('>II', file_content, 0) # 取前2个整数,返回一个元组
    offset = struct.calcsize('>II')

    labelNum = head[1] # label数
    # print(labelNum)
    bitsString = '>' + str(labelNum) + 'B' # fmt格式:'>47040000B'
    label = struct.unpack_from(bitsString, file_content, offset) # 取data数据,返回一个元组
    return np.array(label)

    #KNN算法
    def KNN(test_data, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]#dataSet.shape[0]表示的是读取矩阵第一维度的长度,代表行数
    # distance1 = tile(test_data, (dataSetSize,1)) - dataSet#欧氏距离计算开始
    # print("dataSetSize:")
    # print(dataSetSize)
    distance1 = tile(test_data, (dataSetSize)).reshape((60000,784))-dataSet#tile函数在行上重复dataSetSizec次,在列上重复1次
    # print("distance1.shape")
    # print(distance1.shape)
    distance2 = distance1**2 #每个元素平方
    distance3 = distance2.sum(axis=1)#矩阵每行相加
    distances4 = distance3**0.5#欧氏距离计算结束
    # print(distances4[53843])
    # print(distances4[38620])
    # print(distances4[16186])
    sortedDistIndicies = distances4.argsort() #返回从小到大排序的索引
    classCount=np.zeros((10), np.int32)#10是代表10个类别
    for i in range(k): #统计前k个数据类的数量
    voteIlabel = labels[sortedDistIndicies[i]]
    classCount[voteIlabel] += 1
    max = 0
    id = 0
    print(classCount.shape[0])
    # print(classCount.shape[1])

    for i in range(classCount.shape[0]):
    if classCount[i] >= max:
    max = classCount[i]
    id = i
    print(id)

    # sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)#从大到小按类别数目排序
    return id

    def test_KNN():
    # 文件获取
    #mnist数据集
    # train_image = "F:mnist\train-images-idx3-ubyte"
    # test_image = "F:mnist\t10k-images-idx3-ubyte"
    # train_label = "F:mnist\train-labels-idx1-ubyte"
    # test_label = "F:mnist\t10k-labels-idx1-ubyte"
    #fashion mnist数据集
    train_image = "train-images-idx3-ubyte"
    test_image = "t10k-images-idx3-ubyte"
    train_label = "train-labels-idx1-ubyte"
    test_label = "t10k-labels-idx1-ubyte"
    # 读取数据
    train_x = read_image(train_image) # train_dataSet
    test_x = read_image(test_image) # test_dataSet
    train_y = read_label(train_label) # train_label
    test_y = read_label(test_label) # test_label

    # print(train_x.shape)
    # print(test_x.shape)
    # print(train_y.shape)
    # print(test_y.shape)
    # plt.imshow(train_x[0])
    # plt.show()

    testRatio = 1 # 取数据集的前0.1为测试数据,这个参数比重可以改变
    train_row = train_x.shape[0] # 数据集的行数,即数据集的总的样本数
    test_row=test_x.shape[0]
    testNum = int(test_row * testRatio)
    errorCount = 0 # 判断错误的个数
    for i in range(testNum):
    result = KNN(test_x[i], train_x, train_y, 30)
    # print('返回的结果是: %s, 真实结果是: %s' % (result, train_y[i]))

    print(result, test_y[i])
    if result != test_y[i]:
    errorCount += 1.0# 如果mnist验证集的标签和本身标签不一样,则出错
    error_rate = errorCount / float(testNum) # 计算出错率
    acc = 1.0 - error_rate
    print(errorCount)
    print(" the total number of errors is: %d" % errorCount)
    print(" the total error rate is: %f" % (error_rate))
    print(" the total accuracy rate is: %f" % (acc))

    if __name__ == "__main__":
    test_KNN()#test()函数中调用了读取数据集的函数,并调用分类函数对数据集进行分类,最后对分类情况进行计算
    结果分析:

    输入:mnist数据集或者fashion mnist数据集

    输出:出错率和准确率

    Mnist数据集:

    取k=30,验证集是50个的时候,准确率是1;

    取k=30,验证集是500个的时候,准确率是0.98;

    取k=30,验证集是10000个的时候,准确率是0.84。

    Fashion Mnist数据集

    K=30,验证集是10000的时候,一共的出错个数是1666,准确率是0.8334。

    本文中的数据集采用KNN算法得到了较高的准确率,但是本文中考虑特征属性值对类别判断的重要性一样,改进算法时应该考虑特征属性值对类别判断的重要性不同,两样本间属性的相关距离可以用来度量属性值对类别的重要性,相关距离熵越小,两样本的相似程度越大,类可信度越大;此外本文中应该对不同取值的k进行分别的试验,得到使准确率较高的k,同时在实验多个k的时候,可以采用多线程进行跑实验,缩短时间。



    一生有所追!
  • 相关阅读:
    JS BOM对象 History对象 Location对象
    JS 字符串对象 数组对象 函数对象 函数作用域
    JS 引入方式 基本数据类型 运算符 控制语句 循环 异常
    Pycharm Html CSS JS 快捷方式创建元素
    CSS 内外边距 float positio属性
    CSS 颜色 字体 背景 文本 边框 列表 display属性
    【Android】RxJava的使用(三)转换——map、flatMap
    【Android】RxJava的使用(二)Action
    【Android】RxJava的使用(一)基本用法
    【Android】Retrofit 2.0 的使用
  • 原文地址:https://www.cnblogs.com/BlueBlue-Sky/p/9383120.html
Copyright © 2011-2022 走看看