zoukankan      html  css  js  c++  java
  • 机器学习实战kNN之手写识别

    kNN算法算是机器学习入门级绝佳的素材。书上是这样诠释的:“存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都有标签,即我们知道样本集中每一条数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征比较,算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前K个最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数。最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类”。

    优点:精度高、对异常值不敏感、无数据输入假定。

    缺点:计算复杂度高、空间复杂度高。

    适用数据范围:数值型或标称型。

    算法的python实现:

    def kNN(data, dataSet, dataLabel, k=3, similarity=sim_distance):  
    	scores = [(sim_distance(data, dataSet[i]), dataLabel[i]) for i in range(len(dataSet))]
    	sortedScore = sorted(scores, key=lambda d: d[0], reverse=False) 
    	scores = sortedScore[0:k]
    	
    	classCount = {} 
    	for score in scores:
    		classCount[score[1]] = classCount.get(score[1], 0) + 1
    	
    	sortedClassCount = sorted(classCount.items(), key=lambda d: d[1], reverse=True)
    	return sortedClassCount[0][0]
    		

    下面分为几步骤来学习这个算法:

    (1)准备数据

    (2)测试算法

    先介绍一个这个手写识别系统,简单起见,该系统只能识别数字0---9,需要识别的数字已经使用图形处理软件,处理成具有相同色彩和大小:32*32像素的黑白照片。目录trainingDigits中包含了大约2000个训练样本,目录testDigits中大约有900个测试样本。

    第一步,准备数据:将图片数据转换成测试向量。这一步就是把我们32*32的二进制图像矩阵转换成1*1024的向量。

    def img2vector(filename):
    	vec = []
    	file = open(filename)
    	for i in range(32):
    		line = file.readline()
    		for j in range(32):
    			vec.append(int(line[j]))
    	return vec


    第二步,测试算法准确率,我们用 trainingDigits目录下的样本做训练,来测试testDigits目录下的样本,来计算准确率。

    def test():
    	trainData, trainLabel = [], []
    	trainFileList = os.listdir('digits/trainingDigits/')
    	for filename in trainFileList:
    		trainData.append(img2vector('digits/trainingDigits/%s' % filename))
    		trainLabel.append(int(filename.split('_')[0]))
    		
    	succCnt, failCnt = 0, 0
    	testFileList = os.listdir('digits/testDigits')
    	for filename in testFileList:
    		data = img2vector('digits/testDigits/%s' % filename)
    		num = kNN(data, trainData, trainLabel)
    		if num == int(filename.split('_')[0]):
    			succCnt += 1
    			print 'succ'
    		else:
    			failCnt += 1
    			print 'fail'
    			
    	print "error rate is : %f " % (failCnt/float(failCnt+succCnt))

    我这里测试,K取默认值3,错误率是0.013742,


    不会上传文件,所以把代码贴在下面,测试数据在 http://download.csdn.net/detail/wyb_009/5649337第二章下面

    import os, math
    def sim_distance(a, b):
    	sum_of_squares = sum([pow(a[i]-b[i], 2) for i in range(len(a))])  
    	return sum_of_squares 
    
    def kNN(data, dataSet, dataLabel, k=3, similarity=sim_distance):  
    	scores = [(sim_distance(data, dataSet[i]), dataLabel[i]) for i in range(len(dataSet))]
    	sortedScore = sorted(scores, key=lambda d: d[0], reverse=False) 
    	scores = sortedScore[0:k]
    	
    	classCount = {} 
    	for score in scores:
    		classCount[score[1]] = classCount.get(score[1], 0) + 1
    	
    	sortedClassCount = sorted(classCount.items(), key=lambda d: d[1], reverse=True)
    	return sortedClassCount[0][0]
    		
    def img2vector(filename):
    	vec = []
    	file = open(filename)
    	for i in range(32):
    		line = file.readline()
    		for j in range(32):
    			vec.append(int(line[j]))
    	return vec
    		
    def test():
    	trainData, trainLabel = [], []
    	trainFileList = os.listdir('digits/trainingDigits/')
    	for filename in trainFileList:
    		trainData.append(img2vector('digits/trainingDigits/%s' % filename))
    		trainLabel.append(int(filename.split('_')[0]))
    	print "load train data ok"
    	
    	succCnt, failCnt = 0, 0
    	testFileList = os.listdir('digits/testDigits')
    	for filename in testFileList:
    		data = img2vector('digits/testDigits/%s' % filename)
    		num = kNN(data, trainData, trainLabel)
    		if num == int(filename.split('_')[0]):
    			succCnt += 1
    			print 'succ'
    		else:
    			failCnt += 1
    			print 'fail: kNN get %ld, real is %ls' %(num, int(filename.split('_')[0]))
    			
    	print "error rate is : %f " % (failCnt/float(failCnt+succCnt))
    	
    if __name__ == "__main__":
    	test()







  • 相关阅读:
    洛谷 P1278 单词游戏 【状压dp】
    洛谷 P1854 花店橱窗布置 【dp】
    洛谷 P2258 子矩阵
    洛谷 P3102 [USACO14FEB]秘密代码Secret Code 【区间dp】
    洛谷U14200 Changing 题解 【杨辉三角】
    洛谷P3933 Chtholly Nota Seniorious 【二分 + 贪心 + 矩阵旋转】
    P3932 浮游大陆的68号岛 【线段树】
    洛谷P1273 有线电视网 【树上分组背包】
    NOI2013 矩阵游戏 【数论】
    洛谷P1268 树的重量 【构造 + 枚举】
  • 原文地址:https://www.cnblogs.com/jiangu66/p/3157146.html
Copyright © 2011-2022 走看看