zoukankan      html  css  js  c++  java
  • K近邻算法(二)

    def KNN_classify(k, X_train, y_train, x):
        assert 1 <= k <= X_train.shape[0], "k must be valid"
        assert X_train.shape[0] == y_train.shape[0], 
            "the size of X_train must equal to the size of y_train"
        assert X_train.shape[1] == x.shape[0], 
            "the feature number of x must be equal to X_train"
        # 求距离
        distances = [sqrt(np.sum((x_train - x) ** 2)) for x_train in X_train]
        nearest = np.argsort(distances)
        topK_y = [y_train[i] for i in nearest[:k]]
        votes = Counter(topK_y)
        return votes.most_common(1)[0][0]

     sklearn 库的使用

    from sklearn.neighbors import KNeighborsClassifier
    KNN_classifier = KNeighborsClassifier(n_neighbors=5) 
    #n_neighbors 即是k
    KNN_classifier.fit(X_train, y_train) 
    print(KNN_classifier.predict([x])) 
    # 说明predict传入参数应为矩阵,为了是批量预测。 
    # 若只有一个也要转成矩阵的形式 x.reshape(1,-1)
  • 相关阅读:
    uva1610 Party Games
    uva1442 Cav
    uva1609 Foul Play
    uva1608 Non-boring sequences
    uva12174 滑动窗口+预处理
    uva 1451 数形结合
    light oj 1336 sigma function
    找常用词(字符串处理)问题
    指定排序问题
    完数问题
  • 原文地址:https://www.cnblogs.com/infoo/p/9400736.html
Copyright © 2011-2022 走看看