zoukankan      html  css  js  c++  java
  • 08.手写KNN算法测试

    导入库

    import numpy as np
    from sklearn import datasets
    import matplotlib.pyplot as plt

    导入数据

    iris = datasets.load_iris()

    数据准备

    X = iris.data
    y = iris.target
    X.shape, y.shape
    ((150, 4), (150,))

    数据分割(28开)

    # 因为训练集矩阵和标签向量是分割的,不能单独对某一个进行乱序
    # 需要将其合并整体乱序再分割

    X_join_y = np.hstack([X, y.reshape(-1,1)])

    # 随机,导致每次数据分割结果都会改变
    # 如果有debug需求,需要保证每次运行的分割结果一致
    # 则需要对random进行seed设置

    np.random.seed(1)
    np.random.shuffle(X_join_y)
    train,test = np.vsplit(X_join_y, [int(0.8*len(X_join_y))])
    train.shape,test.shape
    ((120, 5), (30, 5))

    准备data和target

    # X_train, y_train, X_test, y_test 成功拿到了训练集(数据+标签)和测试集(数据+标签)

    X_train = train[:,0:4]
    y_train = train[:,-1]
    X_test = test[:,0:4]
    y_test = test[:,-1]

    KNN手写算法

    import numpy as np
    from math import sqrt
    from collections import Counter
    class KNNClassifier: def __init__(self, k): # 初始化KNN分类器 self.k = k self._X_train = None self._y_train = None def fit(self, X_train, y_train): # 根据训练集X_train, Y_train训练分类器 self._X_train = X_train self._y_train = y_train return self def predict(self, X_predict): # 给定待遇测的数据集X_predict,返回表示X_predict的结果向量 y_predict = [self._predict(x) for x in X_predict] return np.array(y_predict) def _predict(self, x): # 给定单个待遇测数据x,返回x的预测结果值 distances = [sqrt(np.sum((x_train - x) ** 2)) for x_train in self._X_train] nearest = np.argsort(distances) topK_y = [self._y_train[i] for i in nearest[:self.k]] votes = Counter(topK_y) return votes.most_common(1)[0][0] def __repr__(self): return "KNN=(%d)" % self.k
    from sklearn.model_selection import train_test_split
    
    
    result = train_test_split(X, y)
    result
    [array([[7.2, 3. , 5.8, 1.6],
            [5.4, 3.9, 1.3, 0.4],
            [6.5, 3.2, 5.1, 2. ],
            [6.1, 3. , 4.6, 1.4],
            [4.6, 3.2, 1.4, 0.2],
            [6.9, 3.2, 5.7, 2.3],
            [6.1, 2.8, 4. , 1.3],
            [5.7, 3. , 4.2, 1.2],
            [5.8, 2.7, 4.1, 1. ],
            [5.5, 2.5, 4. , 1.3],
            [5.7, 2.5, 5. , 2. ],
            [4.6, 3.4, 1.4, 0.3],
            [5.9, 3.2, 4.8, 1.8],
            [6.3, 2.9, 5.6, 1.8],
            [6.8, 3. , 5.5, 2.1],
            [6.4, 2.7, 5.3, 1.9],
            [6. , 2.9, 4.5, 1.5],
            [6. , 2.2, 4. , 1. ],
            [4.8, 3. , 1.4, 0.1],
            [5.6, 2.5, 3.9, 1.1],
            [7.1, 3. , 5.9, 2.1],
            [6.7, 3.3, 5.7, 2.1],
            [5.5, 2.6, 4.4, 1.2],
            [6.3, 3.3, 4.7, 1.6],
            [6.7, 3.1, 4.7, 1.5],
            [4.3, 3. , 1.1, 0.1],
            [4.8, 3.4, 1.9, 0.2],
            [6.7, 3.3, 5.7, 2.5],
            [6. , 2.7, 5.1, 1.6],
            [6.5, 3. , 5.5, 1.8],
            [4.9, 2.5, 4.5, 1.7],
            [5. , 3.5, 1.3, 0.3],
            [5.9, 3. , 4.2, 1.5],
            [5.5, 2.4, 3.8, 1.1],
            [6.2, 2.2, 4.5, 1.5],
            [6.3, 2.7, 4.9, 1.8],
            [4.4, 3. , 1.3, 0.2],
            [7.7, 3. , 6.1, 2.3],
            [7. , 3.2, 4.7, 1.4],
            [6.4, 2.8, 5.6, 2.2],
            [5.7, 2.8, 4.5, 1.3],
            [6.4, 2.9, 4.3, 1.3],
            [5.6, 3. , 4.1, 1.3],
            [6.3, 2.8, 5.1, 1.5],
            [4.9, 3.6, 1.4, 0.1],
            [6. , 3.4, 4.5, 1.6],
            [5.7, 4.4, 1.5, 0.4],
            [4.8, 3. , 1.4, 0.3],
            [5.4, 3.7, 1.5, 0.2],
            [5.4, 3.4, 1.5, 0.4],
            [5. , 2.3, 3.3, 1. ],
            [6.9, 3.1, 4.9, 1.5],
            [5.1, 3.8, 1.9, 0.4],
            [6.4, 2.8, 5.6, 2.1],
            [5.1, 3.8, 1.5, 0.3],
            [5. , 3.4, 1.5, 0.2],
            [5.1, 3.3, 1.7, 0.5],
            [5.2, 2.7, 3.9, 1.4],
            [6.1, 2.6, 5.6, 1.4],
            [7.7, 2.8, 6.7, 2. ],
            [5.8, 2.7, 5.1, 1.9],
            [6.8, 2.8, 4.8, 1.4],
            [4.4, 3.2, 1.3, 0.2],
            [5.3, 3.7, 1.5, 0.2],
            [6.9, 3.1, 5.4, 2.1],
            [5.1, 2.5, 3. , 1.1],
            [5.7, 2.8, 4.1, 1.3],
            [6.4, 3.1, 5.5, 1.8],
            [6.2, 3.4, 5.4, 2.3],
            [5.8, 2.7, 5.1, 1.9],
            [6.3, 2.5, 4.9, 1.5],
            [5.8, 2.6, 4. , 1.2],
            [4.6, 3.1, 1.5, 0.2],
            [4.9, 3.1, 1.5, 0.2],
            [5.6, 2.9, 3.6, 1.3],
            [5.1, 3.7, 1.5, 0.4],
            [5. , 3.2, 1.2, 0.2],
            [6.5, 3. , 5.8, 2.2],
            [7.3, 2.9, 6.3, 1.8],
            [5.2, 3.4, 1.4, 0.2],
            [4.5, 2.3, 1.3, 0.3],
            [5.5, 2.3, 4. , 1.3],
            [6.5, 3. , 5.2, 2. ],
            [5.5, 2.4, 3.7, 1. ],
            [7.6, 3. , 6.6, 2.1],
            [5. , 3.6, 1.4, 0.2],
            [5.9, 3. , 5.1, 1.8],
            [6.3, 2.5, 5. , 1.9],
            [6.1, 3. , 4.9, 1.8],
            [4.9, 3. , 1.4, 0.2],
            [6.7, 3. , 5.2, 2.3],
            [5.1, 3.5, 1.4, 0.3],
            [6.3, 2.3, 4.4, 1.3],
            [4.4, 2.9, 1.4, 0.2],
            [6.8, 3.2, 5.9, 2.3],
            [5.1, 3.8, 1.6, 0.2],
            [7.2, 3.6, 6.1, 2.5],
            [5.7, 3.8, 1.7, 0.3],
            [5. , 2. , 3.5, 1. ],
            [5. , 3. , 1.6, 0.2],
            [4.8, 3.4, 1.6, 0.2],
            [4.8, 3.1, 1.6, 0.2],
            [6.7, 3.1, 5.6, 2.4],
            [5.8, 2.8, 5.1, 2.4],
            [5.8, 4. , 1.2, 0.2],
            [6.1, 2.8, 4.7, 1.2],
            [5.4, 3.9, 1.7, 0.4],
            [6.5, 2.8, 4.6, 1.5],
            [4.9, 3.1, 1.5, 0.1],
            [5.4, 3.4, 1.7, 0.2],
            [4.9, 2.4, 3.3, 1. ],
            [5.1, 3.4, 1.5, 0.2]]),
     array([[6.2, 2.9, 4.3, 1.3],
            [6.7, 3. , 5. , 1.7],
            [5.2, 4.1, 1.5, 0.1],
            [5.7, 2.6, 3.5, 1. ],
            [7.4, 2.8, 6.1, 1.9],
            [5.6, 3. , 4.5, 1.5],
            [6.9, 3.1, 5.1, 2.3],
            [6. , 2.2, 5. , 1.5],
            [5.5, 3.5, 1.3, 0.2],
            [6.7, 2.5, 5.8, 1.8],
            [7.2, 3.2, 6. , 1.8],
            [6. , 3. , 4.8, 1.8],
            [5.2, 3.5, 1.5, 0.2],
            [5.1, 3.5, 1.4, 0.2],
            [5. , 3.3, 1.4, 0.2],
            [5.6, 2.8, 4.9, 2. ],
            [5.6, 2.7, 4.2, 1.3],
            [5. , 3.5, 1.6, 0.6],
            [7.9, 3.8, 6.4, 2. ],
            [6.3, 3.4, 5.6, 2.4],
            [5. , 3.4, 1.6, 0.4],
            [6.2, 2.8, 4.8, 1.8],
            [5.4, 3. , 4.5, 1.5],
            [5.5, 4.2, 1.4, 0.2],
            [4.6, 3.6, 1. , 0.2],
            [6.1, 2.9, 4.7, 1.4],
            [6.4, 3.2, 5.3, 2.3],
            [5.7, 2.9, 4.2, 1.3],
            [7.7, 2.6, 6.9, 2.3],
            [7.7, 3.8, 6.7, 2.2],
            [6.3, 3.3, 6. , 2.5],
            [5.8, 2.7, 3.9, 1.2],
            [6.6, 2.9, 4.6, 1.3],
            [4.7, 3.2, 1.6, 0.2],
            [6.7, 3.1, 4.4, 1.4],
            [6.4, 3.2, 4.5, 1.5],
            [4.7, 3.2, 1.3, 0.2],
            [6.6, 3. , 4.4, 1.4]]),
     array([2, 0, 2, 1, 0, 2, 1, 1, 1, 1, 2, 0, 1, 2, 2, 2, 1, 1, 0, 1, 2, 2,
            1, 1, 1, 0, 0, 2, 1, 2, 2, 0, 1, 1, 1, 2, 0, 2, 1, 2, 1, 1, 1, 2,
            0, 1, 0, 0, 0, 0, 1, 1, 0, 2, 0, 0, 0, 1, 2, 2, 2, 1, 0, 0, 2, 1,
            1, 2, 2, 2, 1, 1, 0, 0, 1, 0, 0, 2, 2, 0, 0, 1, 2, 1, 2, 0, 2, 2,
            2, 0, 2, 0, 1, 0, 2, 0, 2, 0, 1, 0, 0, 0, 2, 2, 0, 1, 0, 1, 0, 0,
            1, 0]),
     array([1, 1, 0, 1, 2, 1, 2, 2, 0, 2, 2, 2, 0, 0, 0, 2, 1, 0, 2, 2, 0, 2,
            1, 0, 0, 1, 2, 1, 2, 2, 2, 1, 1, 0, 1, 1, 0, 1])]
    my_knn_clf = KNNClassifier(k=3)
    my_knn_clf.fit(result[0], result[2])
    KNN=(3)

    y_predict = my_knn_clf.predict(result[1])
    sum(y_predict == result[3])
    sum(y_predict == result[3])/len(result[3])
  • 相关阅读:
    FSCapture 取色工具(绿色版 )
    Java EE.JavaBean
    Java EE.JSP.内置对象
    Java EE.JSP.动作组件
    Java EE.JSP.指令
    Java EE.JSP.脚本
    21、多态与多态性、内置方法、反射、异常处理
    20、继承的应用、super、组合
    19、property、绑定方法(classmethod、staticmethod)、继承
    18、类
  • 原文地址:https://www.cnblogs.com/waterr/p/14039173.html
Copyright © 2011-2022 走看看