k-Nearest Neighbors (kNN) Study Notes

    I have just started working through the book *Machine Learning in Action*, and the first algorithm it introduces is k-nearest neighbors (kNN).

    Machine learning algorithms can be divided into supervised and unsupervised learning. Supervised learning addresses two kinds of problems, classification and regression, while unsupervised learning, having no target values or class labels, instead clusters the data. I am not yet familiar with unsupervised learning; the neural networks I used in earlier projects did not touch on unsupervised methods either.

    A classification problem takes a given instance and assigns it to the appropriate class; a regression problem predicts a numeric value, the simplest example probably being the one-variable quadratic fit used in physics lab experiments.

    For classification, k-nearest neighbors is a simple and effective algorithm, and its idea is very straightforward. Assume a known sample set S, where each sample si in S has an associated class cj drawn from a finite class set C. Then the class of a new data point d can be determined as follows:

    1. Compute the Euclidean distance between d and every si in S (a small worked example follows this list);
    2. Sort all of the distances in ascending order;
    3. Take the k nearest samples s1~sk, with corresponding classes c1~ck;
    4. The class that occurs most frequently among c1~ck is the class assigned to d.
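
    As a tiny worked illustration of step 1 (my own, not from the book), take the query point [0, 0] and two of the points from the toy data set defined in create_data_set below:

    # Euclidean distance from the query [0, 0] to two of the toy samples.
    from math import sqrt

    d1 = sqrt((1.0 - 0) ** 2 + (1.1 - 0) ** 2)   # distance to [1.0, 1.1], about 1.487
    d2 = sqrt((0.0 - 0) ** 2 + (0.1 - 0) ** 2)   # distance to [0, 0.1], exactly 0.1
    print d1, d2                                 # [0, 0.1] is the nearer neighbor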

    The algorithm is also easy to implement; a Python version follows:

    #-*- coding: utf-8 -*-

    from numpy import *
    import operator

    def create_data_set():
        # toy data set: four 2-D points and their class labels
        group   = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
        labels  = ['A', 'A', 'B', 'B']
        return group, labels

    def classify0(inX, data_set, labels, k):
        data_set_size = data_set.shape[0]                    # number of training samples
        diff_mat = tile(inX, (data_set_size, 1)) - data_set  # tile repeats inX once per training sample
        diff_mat2 = diff_mat ** 2
        distances2 = diff_mat2.sum(axis = 1)                 # squared Euclidean distances
        distances = distances2 ** 0.5
        sorted_dist_index = distances.argsort()              # indices ordered from nearest to farthest
        class_count = {}
        for i in xrange(k):                                  # vote among the k nearest samples
            vote_ilabel = labels[sorted_dist_index[i]]
            class_count[vote_ilabel] = class_count.get(vote_ilabel, 0) + 1
        sorted_class_count = sorted(class_count.iteritems(), key = operator.itemgetter(1), reverse = True)
        return sorted_class_count[0][0]                      # most common class among the k neighbors

    def test_classify0():
        group, labels = create_data_set()
        res1 = classify0([0, 0], group, labels, 3)
        print 'classification result:', res1

    if __name__ == '__main__':
        test_classify0()

    Here, classify0 is the implementation of the k-nearest neighbors algorithm; the numpy package does the array arithmetic.
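
    As a rough sketch (mine, not the book's) of what the two key numpy calls in classify0 are doing, here is what tile and argsort produce for the query [0, 0] against the four toy samples:

    # tile repeats the query so it can be subtracted from the whole training matrix at once.
    from numpy import array, tile

    print tile([0, 0], (4, 1))
    # [[0 0]
    #  [0 0]
    #  [0 0]
    #  [0 0]]

    # argsort returns the indices that would sort the distances ascending,
    # i.e. the training samples ordered from nearest to farthest.
    distances = array([1.4866, 1.4142, 0.0, 0.1])   # distances from [0, 0] to the 4 toy points
    print distances.argsort()
    # [2 3 1 0]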

    Test result: [screenshot of the console output]

    I also reworked the examples given in Machine Learning in Action. They are all much the same; most of the work lies in importing the external data. :)

    Dating-site example

    #-*- coding: utf-8 -*-

    from numpy import *
    import operator
    import matplotlib.pyplot as pyplot
    # classify0 from the first listing is assumed to be defined in this same file

    def file2matrix(filename):
        fr = open(filename)
        read_lines = fr.readlines()
        sample_count = len(read_lines)
        print '%d lines in "%s"' % (sample_count, filename)
        sample_matrix = zeros((sample_count, 3))    # 3 features per sample
        label_vector = []
        isample = 0
        for line in read_lines:
            line = line.strip()
            one_sample_list = line.split('\t')      # tab-separated fields
            sample_matrix[isample, :] = [double(item) for item in one_sample_list[0 : 3]]
            label_vector.append(int(one_sample_list[-1]))   # the last value on each line is the class label
            isample += 1
        return sample_matrix, label_vector

    def auto_normalize(data_set):
        min_val = data_set.min(0)                   # column-wise minimum
        max_val = data_set.max(0)                   # column-wise maximum
        ranges = max_val - min_val
        m = data_set.shape[0]
        norm_set = data_set - tile(min_val, (m, 1))
        norm_set = norm_set / tile(ranges, (m, 1))  # scale every feature into [0, 1]
        return norm_set, ranges, min_val

    def test_dating_classify():
        dating_matrix, dating_label = file2matrix('datingTestSet2.txt')
    #     fig = pyplot.figure()
    #     ax = fig.add_subplot(111)
    #     ax.scatter(dating_matrix[:, 0], dating_matrix[:, 1], 15.0 * array(dating_label), 15 * array(dating_label))
    #     pyplot.show()
        norm_matrix, _, _ = auto_normalize(dating_matrix)
        verify_ratio = 0.1                          # hold out the first 10% of samples for verification
        samples_count = norm_matrix.shape[0]
        verify_count = int(verify_ratio * samples_count)
        error_count = 0.0
        for i in xrange(verify_count):
            classify_result = classify0(norm_matrix[i, :], norm_matrix[verify_count : samples_count, :],
                                        dating_label[verify_count : samples_count], 9)
            print 'classifier says %d, actual class is %d' % (classify_result, dating_label[i])
            if (classify_result != dating_label[i]):
                error_count += 1
        print 'classification error rate: %.2f' % (error_count / float(verify_count))

    if __name__ == '__main__':
        test_dating_classify()
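
    One detail worth pointing out in the listing above is auto_normalize: the three features have very different ranges, so without min-max scaling the largest feature would dominate the Euclidean distance. Below is a minimal sketch of that scaling with two made-up rows (not taken from datingTestSet2.txt); numpy broadcasting gives the same result as the explicit tile calls:

    # Min-max scaling sketch; the values are invented, in the same three-feature format.
    from numpy import array

    raw = array([[40000.0, 8.3, 0.95],
                 [15000.0, 7.2, 1.67]])
    min_val = raw.min(0)                 # column-wise minimum
    ranges  = raw.max(0) - min_val       # column-wise range
    norm = (raw - min_val) / ranges      # every feature now lies in [0, 1]
    print norm
    # [[ 1.  1.  0.]
    #  [ 0.  0.  1.]]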

    Test result: [screenshot of the console output]

    Handwritten digit recognition example

    #-*- coding: utf-8 -*-

    from numpy import *
    import operator
    from os import listdir
    # classify0 from the first listing is assumed to be defined in this same file

    def img2vector(filename):
        # each digit file is a 32x32 grid of '0'/'1' characters; flatten it into a 1x1024 row vector
        img_vector = zeros((1, 1024))
        fr = open(filename)
        for i in xrange(32):
            line_str = fr.readline()
            line_str = line_str.strip()
            for j in xrange(32):
                img_vector[0, 32 * i + j] = int(line_str[j])
        return img_vector

    def test_handwritting_classify():
        handwritting_labels = []
        training_files = listdir('trainingDigits')
        samples_count = len(training_files)
        training_matrix = zeros((samples_count, 1024))
        # construct training matrix; file names look like '<digit>_<index>.txt',
        # so the true digit is the part of the name before the underscore
        for i in xrange(samples_count):
            file_name_str = training_files[i]
            file_str = file_name_str.split('.')[0]
            label_str = int(file_str.split('_')[0])
            handwritting_labels.append(label_str)
            training_matrix[i, :] = img2vector('trainingDigits/%s' % file_name_str)
        # classify every file in the test directory and count the mistakes
        test_files = listdir('testDigits')
        error_count = 0
        tests_count = len(test_files)
        for i in xrange(tests_count):
            file_name_str = test_files[i]
            file_str = file_name_str.split('.')[0]
            label_str = int(file_str.split('_')[0])
            vector_under_test = img2vector('testDigits/%s' % file_name_str)
            classify_result = classify0(vector_under_test, training_matrix, handwritting_labels, 3)
            print 'recognized as %d, actual digit is %d' % (classify_result, label_str)
            if (classify_result != label_str):
                error_count += 1
        print '%d handwriting errors in total, error rate %.2f' % (error_count, error_count / float(tests_count))

    if __name__ == '__main__':
        test_handwritting_classify()
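
    For reference, each file under trainingDigits and testDigits is a 32x32 grid of '0'/'1' characters, and the digit before the underscore in the file name is the true label. A quick usage sketch of img2vector (the file name here is only illustrative):

    # Hypothetical usage; '3_7.txt' stands for any digit file labeled 3.
    vec = img2vector('trainingDigits/3_7.txt')
    print vec.shape          # (1, 1024): the 32x32 grid flattened into a single row
    print int(vec.sum())     # number of '1' pixels in the digit image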

     

    Test result: [screenshot of the console output]

    Summary

    In principle, k-nearest neighbors is accurate and effective, and it matches how people naturally classify things: put plainly, whatever you are closest to is what you are.

    In practice, however, kNN is very slow (high computational complexity: every query has to compute a distance to every training sample) and demands a lot of storage (high space complexity: the entire training set must be kept in memory).
