zoukankan      html  css  js  c++  java
  • 机器学习实战K-近邻算法

    今天开始学习机器学习,第一章是K-近邻算法,有不对的地方请指正
    大概总结一下近邻算法写分类器步骤:
    1. 计算测试数据与已知数据的特征值的距离,离得越近越相似
    2. 取距离最近的K个已知数据的所属分类
    3. 最后统计K个值的分类分别出现的概率,返回最多的一个属性,即为测试数据的所属分类
    4. 至于怎么把文本转换成numpy的类型,需要学习numpy模块的相关知识,附上
    numpy学习连接 http://old.sebug.net/paper/books/scipydoc/numpy_intro.html

    #-*- coding:utf-8 *-*-
    
    from numpy import *
    import operator   #计算模块
    import matplotlib
    import matplotlib.pyplot as plt
    import time
    import random
    from mpl_toolkits.mplot3d import Axes3D
    from os import listdir
    import time
    
    
    def createDataSet():
        group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
        labels = ['A','A','B','B']
        return group,labels
    
    
    #A,B分类
    def classify0(inX, dataSet, labels, k):
        dataSetSize = dataSet.shape[0]
        diffMat = tile(inX,(dataSetSize,1)) - dataSet #tile函数把inx复制datasetsize行1列
        sqDiffMat = diffMat**2
        #print "sqDiffMat : ",sqDiffMat
        sqDistance = sqDiffMat.sum(axis = 1)
        distance = sqDistance**0.5
        #print "distance : ",distance
        sortedDistIndicies = distance.argsort()  #返回从小到大的元素的下标,比如[1 3 2 4].argsort()返回[0 2 1 3]
        #print "****",sortedDistIndicies
        classCount = {}
        for i in range(k):
            voteIlabel = labels[sortedDistIndicies[i]]   #统计各个现有值所属的特征向量
            #print sortedDistIndicies[i],voteIlabel
            classCount[voteIlabel] = classCount.get(voteIlabel,0)+1  #统计各个特征向量出现的次数
        sortedClassCount = sorted(classCount.iteritems(),key = operator.itemgetter(1),reverse = True) 
        #operator.itemgetter()从小到大排序
        #print "sortedClassCount : ",sortedClassCount
        return sortedClassCount[0][0]
    
    group,labels =  createDataSet()
    
    #print classify0([0,0], group, labels, 3)
    
    # # a = [('b',2),('a',1),('c',0)]
    # a=[('b',2),('a',2),('a',1),('c',0)]
    # b = sorted(a,key =  operator.itemgetter(0)) #优先根据第一个元素排序
    # print b
    # b = sorted(a,key =  operator.itemgetter(1)) #优先根据第二个元素排序
    # print b
    # b = sorted(a,key =  operator.itemgetter(1,0)) #优先根据第二个元素排序,当第二个元素相等的情况下根据第一个元素排序
    # print b
    
    #解析数据
    def file2matrix(filename):
        with open(filename) as f:
            lines = f.readlines()
            matrixNumber = len(lines)
            print 'the all lines is :',matrixNumber
            #matrix = zeros((matrixNumber,3),dtype = 'int') #生成空的n行3列的矩阵
            matrix = zeros((matrixNumber,2))
            vector = []
            index = 0     #矩阵索引
            for line in lines:
                line = line.strip()
                data = line.split("	")
                matrix[index:] = data[0:2]   #把提取出来的复制到矩阵里面
                vector.append(int((data[-1])))  #最后一个特征值作为特征向量
                index+=1
            return matrix,vector
    
    
    #生成文本数据
    def createdata(filename):
        with open(filename,'w') as f:
            for i in range(1000):
                r1 = int(random.random()*1000)
                r2 = 0
                if(0<=r1<=200):
                    r2 = 1
                if(200<r1<=400):
                    r2 = 2
                if(400<r1<=600):
                    r2 = 3
                if(600<r1<=800):
                    r2 = 4
                if(800<r1<=1000):
                    r2 = 5
                r1 = str(r1)
                r2 = str(r2)
                #r2 = str(int(random.random()*10))
                r3 = str(int(random.random()*10))
                f.writelines(r3+'	'+r1+'	'+r2+'
    ')
    
    #createdata(r'D:	est_packagesknntest.txt')
    
    '''
    datat,labels = file2matrix(r'D:	est_packagesknntest.txt')
    print datat
    # print datat[:,1] #纵向的第二列
    # print datat[:][1] #横向的第二列
    print labels
    fig = plt.figure()  #生成容器
    plt.title('favorite table data')
    ax = fig.add_subplot(1,1,1,projection='3d') #3D模型
    ax.scatter(datat[:,0],datat[:,1],datat[:,2],array(labels),array(labels),array(labels))  #使用datat的第二列和第三列作为X轴和Y轴的值
    ax.legend()
    plt.show()
    
    fig = plt.figure()
    ax = fig.add_subplot(1,1,1) #把容器划分为1行1列,图像画在第一格,背景颜色为axisbg = ‘’
    ax.scatter(datat[:,1],datat[:,2],array(labels),array(labels))  #使用datat的第二列和第三列作为X轴和Y轴的值
    #ax.grid(True) #是否显示网格
    # plt.show()
    plt.show()
    '''
    
    #归一化,(old-min)/(max-min)
    def autoNormal(dataSet):
        maxVals = dataSet.max(0)  #纵向找到每一个样本的最大特征值
        minVals = dataSet.min(0)
        ranges = maxVals - minVals #计算差值
        normalValue = zeros(shape(dataSet))
        m = dataSet.shape[0]
        normalValue = dataSet - tile(minVals,(m,1))   #计算(old-min)
        normalValue = normalValue/tile(ranges,(m,1))
        return normalValue,ranges,minVals
    
    
    #归一化特征值之后
    datat,labels = file2matrix(r'D:	est_packagesknntest.txt')
    normalValue,ranges,minVals = autoNormal(datat)
    print normalValue
    fig = plt.figure()
    ax = fig.add_subplot(1,1,1) #把容器划分为1行1列,图像画在第一格,背景颜色为axisbg = ‘’
    ax.scatter(normalValue[:,0],normalValue[:,1],array(labels),array(labels))  #使用datat的第二列和第三列作为X轴和Y轴的值
    #ax.grid(True) #是否显示网格
    # plt.show()
    plt.show()
    
    #约会网站测试函数
    def datinggTest():
        datat,labels = file2matrix(r'D:	est_packagesknntest.txt')
        normal,ranges,minvals = autoNormal(datat)
        testData = 0.5  #10%用来测试,90%用来训练
        testNumber = normal.shape[0]  #总行数
        numberTestValues = int(testNumber*testData)  #测试行数
        error = 0.0
        for i in range(numberTestValues):
            labelValue = classify0(normal[i,:], normal[numberTestValues:testNumber,:], labels[numberTestValues:testNumber], 3)
            if (labelValue != labels[i]):
                error+=1.0
                print "this time is error the error is %s, the right is %s"%(labelValue,labels[i])
            else:
                print "all right ,the number is %s, the right is %s"%(labelValue,labels[i])
        error_result = ((error/float(numberTestValues)))
        print "your error_result is %s"%(error_result)
        print 'error is :',error
    datinggTest()
    
    #把二进制文件转化为np.array
    def img2Vector(filename):
        with open(filename) as f:
            vector = zeros((1,1024))
            for i in range(32):
                line = f.readline()
                for j in range(32):
                    vector[0,32*i+j] = line[j]
        return vector
    vector = img2Vector(r'D:	est_packages	rainingDigits_0.txt')
    print vector[0,11:17]
    
    #手写数字识别系统测试代码
    def handwritingClassTest():
        startTime = time.ctime()
        handLabels = []
        trainFile = listdir(r'D:	est_packages	rainingDigits')
        m = len(trainFile)
        trainMat = zeros((m,1024))
        for i in range(m):
            fileName = trainFile[i]
            file = fileName.split('.')[0]
            classNumber = file.split('_')[0]
            handLabels.append(classNumber)
            trainMat[i,:] = img2Vector(r'D:	est_packages	rainingDigits\%s'%fileName) 
        testFiles = listdir(r'D:	est_packages	estDigits')
        nTest = len(testFiles)
        error = 0.0
        for i in range(nTest):
            fileName = testFiles[i]
            file = fileName.split('.')[0]
            classNumber = file.split('_')[0]
            testMat = img2Vector(r'D:	est_packages	estDigits\%s'%fileName) 
            testLabels = classify0(testMat, trainMat, handLabels, 3)
            if (testLabels != classNumber):
                error+=1.0
                print 'error , error number is %s, the right number is %s'%(testLabels,classNumber)
            else:
                print 'right'
        error = error/float(nTest)
        stopTime = time.ctime()
        print 'all right ,the error_result is %s'%(error)
        print 'the process start at %s'%(startTime)
        print 'the process stop at %s'%(stopTime)
    
    handwritingClassTest()
    
    欢迎来邮件交流:lq65535@163.com
  • 相关阅读:
    day26
    day 25
    java.io.IOException: java.net.ConnectException: Call From master/192.168.58.128 to master:10020 failed on connection exception: java.net.ConnectException: 拒绝连接;
    疫情可视化系统
    使用eclipse创建spring cloud的eureka客户端和eureka服务端
    连接虚拟机的hive时进程自动杀死
    在Ubuntu18.04的Docker中安装Oracle镜像及简单使用
    Ubuntu16.04 上Docker 中安装SQL Server 2017
    docker
    Docker镜像报错
  • 原文地址:https://www.cnblogs.com/lq1024/p/7593639.html
Copyright © 2011-2022 走看看