zoukankan      html  css  js  c++  java
  • KNN算法

    核心思想:简单的统计距离目标点最近的K个节点里数目最多的标签赋予目标点(给定已经标注好类别的训练集,然后对测试集中的样本进行分类);

    步骤

     

    总结:距离度量、K值选取、分类决策规则

    (1)距离度量

     

    (2)K值选取

    涉及到近似误差和估计误差:

    (1)K值越大,近似误差越大,而估计误差越小;

    (1)K值越小,近似误差越小,而估计误差越大;

    一般,这么理解估计误差和近似误差:

    近似误差可以理解为模型估计值和实际值之间的差距;(针对训练集)

    估计误差可以理解为模型的估计参数和实际参数的差距;(针对测试集)

    比如,这里对于KNN算法,K值越小会导致特征空间被划分成更多的子空间,对训练集的预测更准,近似误差更小了,但由于模型参数更多了(可能会受到过拟合的影响),在计算测试集的时候的估计误差更大了;K值越大,模型特征空间被划分成的子空间更少,对训练集的预测准确度降低了,近似误差增加了,但由于模型参数变少了,具有更好的泛化能力,对测试集的效果提升了,即估计误差变小了。在实际应用中,一般K取一个较小的值,例如采用交叉验证法(一部分作为训练集,一部分作为测试集)来选择最优的K值。

    (3)分类决策规则

     

    缺陷:由于朴素KNN对于每一个样本,都要计算其他样本与它的距离,计算效率低;

    KNN算法的描述为:

    1)计算测试数据与各个训练数据之间的距离;

    2)按照距离的递增关系进行排序;

    3)选取距离最小的K个点;

    4)确定前K个点所在类别的出现频率;

    5)返回前K个点中出现频率最高的类别作为测试数据的预测分类。

    朴素KNN实现代码如下:(使用mnist数据集)

    #coding:UTF-8
    # 2018_03_17
    #朴素KNN实现
    #================================================================================================
    import cPickle as pickle
    import gzip
    import numpy as np
    import math
    import time
    #开始计时
    time_start=time.time()
    
    #记载数据
    def load_data(data_file):
        with gzip.open(data_file , 'rb') as f:
            train_set , valid_set , test_set = pickle.load(f)
        return train_set[0] , train_set[1] , test_set[0] , test_set[1]
    #计算欧式距离
    def cal_distance(x , y):
        return ((x - y) * (x - y).T)[0 , 0]
        #return math.sqrt((x - y) * (x - y))
    #得到预测值
    def get_prediction(train_y , result):
        result_dict = {}
        for i in range(len(result)):
            if train_y[result[i]] not in result_dict:
                result_dict[train_y[result[i]]] = 1
            else:
                result_dict[train_y[result[i]]] += 1
        predict = sorted(result_dict.items() , key = lambda d:d[1])
        return predict[0][0]
    #KNN算法
    def k_nn(train_data , train_y , test_data , k):
        m = np.shape(test_data)[0] #需要计算的样本个数
        m_train = np.shape(train_data)[0]
        predict = []
    
        for i in range(m):
            #对每一个需要计算的样本计算其与所有的训练数据之间的欧式距离
            distance_dict = {}
            for i_train in range(m_train):
                distance_dict[i_train] = cal_distance(train_data[i_train , :] , test_data[i , :])
                #对距离进行排序,得到最终的前k个作为最终的预测值
            distance_result = sorted(distance_dict.items() , key = lambda d : d[1])
            #取出前k个的结果作为最终的结果
            result = []
            count = 0
            for x in distance_result:
                if count >= k:
                    break
                result.append(x[0])
                count += 1
            #得到预测
            predict.append(get_prediction(train_y , result))
        return predict
    
    def get_correct_rate(result , test_y):
        m = len(result)
        correct = 0.0
        for i in range(m):
            if result[i] == test_y[i]:
                correct += 1
        print correct
        print m
        return correct / m
    
    if __name__ == '__main__':
        #1、导入
        print "------------1、load data----------------"
        train_x , train_y , test_x , test_y = load_data("../dataset/mnist.pkl.gz")
        #2、利用knn计算
        train_x = np.mat(train_x)
        test_x = np.mat(test_x)
        print "-------------2、K-NN---------------------"
        result = k_nn(train_x , train_y , test_x[: 10 , :] , 10)
        print result
        #3、预测正确性
        print "-------------3、correct rate-------------"
        print get_correct_rate(result , test_y)
        time_end = time.time()
        print('totally cost:', time_end - time_start)

    当K=10的时候,运行时间为12.36s,可见非常耗时;

    后期会贴出速度更快的方法:K-d树和局部敏感哈希表;

  • 相关阅读:
    js图片放大
    js编写点名器
    javascript中的math和随机数
    python中 __slots__
    python中 @property
    CentOS 6.5通过yum安装 MySQL-5.5
    linux下环境搭建
    oracle:ORA-01940无法删除当前已连接用户的解决方案
    不同版本apache(免安装)下部署Javaee多套项目
    使用poi处理excel
  • 原文地址:https://www.cnblogs.com/zf-blog/p/8590567.html
Copyright © 2011-2022 走看看