zoukankan      html  css  js  c++  java
  • 机器学习——k近邻算法(KNN)

    import math
    import numpy as np
    from collections import Counter
    class KNNClassfiy(object):
        def __init__(self,k):
        #判断k有效
            assert k>=1,'k must be valid'
            self.k=k
            self._xTrain=None
            self._yTrain=None
    
    
        def fit(self,xTrain,yTrain):
        #判断输入的训练集有效
            assert xTrain.shape[0]==yTrain.shape[0],
                'The size of xTrain must be equals to the size of yTrain'
        #判断K有效   
            assert self.k<=xTrain.shape[0],
                'The size of xTrain must be least at k'
            self._xTrain=xTrain
            self._yTrain=yTrain
            return self
    
        def predict(self,X_predict):
            # X_predict是预测数据数组,判断预测数据合法性,必须是二维数组
            assert X_predict.shape[1]==self._xTrain.shape[1],
                'The feature of x must be equal to xTrain'
            assert self._xTrain is not None and self._yTrain is not None,
                'must fit before predict'
            y_predict=[self._predict(x) for  x in X_predict]
            return np.array(y_predict)
    
        def _predict(self,x):
            distances=[math.sqrt(np.sum((xTrain-x)**2)) for xTrain in self._xTrain]
            nearest=np.argsort(distances)
            top_y=[self._yTrain[i] for i in nearest[:self.k]]
            votes=Counter(top_y)
            print(votes.most_common(1))
            return votes.most_common(1)[0][0]
        def __repr__(self):
            return self.k
    
    KNN_clf=KNNClassfiy(k=6);
    #先训练后预测
    xTrain=np.array([[4.5,3.2],
                     [5.8,4.1],
                     [6.7,5.3],
                     [8.6,7.1],
                     [3.8,2.5],
                     [5.3,4.4],
                     [9.4,8.6],
                     [11.8,9.4],
                     [3.8,3.2],
                     [12.8,10.1]])
    yTrain=np.array([0,0,1,1,0,0,1,1,0,1])
    KNN_clf.fit(xTrain=xTrain,yTrain=yTrain)
    x_predict=np.array([[6.9,5.7],[3.4,2.8]])
    a=KNN_clf.predict(x_predict)
    print(a[0],a[1])
    

    代码比较简单,主要逻辑在于预测部分。

    调用matplotlib绘制图形分布图

    在这里插入图片描述

    步骤可简化如下:

    • 确定k值
    • 训练数据集
    • 预测函数

    K近邻算法主要解决分类问题,是机器学习中最简单的最基础的一种算法。

  • 相关阅读:
    MongoDB学习笔记-查询
    【ASP.NET MVC 回顾】HtmlHepler应用-分页组件
    浅谈.NET中闭包
    浅析 public static void main(String[] args)
    关于SQL Server 无法生成 FRunCM 线程(不完全)
    设计模式-02.单例模式
    设计模式-01.工厂模式
    GC垃圾回收机制
    Spring自学笔记
    关于面试
  • 原文地址:https://www.cnblogs.com/hzcya1995/p/13309462.html
Copyright © 2011-2022 走看看