zoukankan      html  css  js  c++  java
  • k近邻算法(简单版)

    存在一个样本数据集,也称作训练样本集,并且样本中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系,输入没有标签的新数据后,将新数据的每个特征与样本集中的数据对应的特征进行比较,然后算法提取样本集中特征最相似的数据(最近邻)的分类标签。一般来说,我们只选择样本集中前k个最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数,最后,选择k个最相似的数据中出现次数最多的分类,作为新数据的分类。

    KNN三要素:k值选择,距离度量分类决策规则(取均值的决策规则).

    import numpy as np
    from cmath import sqrt
    from collections import Counter
    import matplotlib.pyplot as plt
    
    #用来模拟的数据集
    raw_data_X = [[3.393533211, 2.331273381],
                  [3.110073483, 1.781539638],
                  [1.343808831, 3.368360954],
                  [3.582294042, 4.679179110],
                  [2.280362439, 2.866990263],
                  [7.423436942, 4.696522875],
                  [5.745051997, 3.533989803],
                  [9.172168622, 2.511101045],
                  [7.792783481, 3.424088941],
                  [7.939820817, 0.791637231]
                 ]
    raw_data_y = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
    
    #创建为numpy数组
    x_train = np.array(raw_data_X)
    y_train = np.array(raw_data_y)
    
    print(x_train)
    print(y_train)
    
    '''
    plt.scatter(x_train[y_train==0, 0], x_train[y_train==0, 1], color = 'g')
    plt.scatter(x_train[y_train==1, 0], x_train[y_train==1, 1], color = 'r')
    plt.show()
    '''
    
    x = np.array([8.093607318, 3.365731514])
    #绘制散点图(注: y是特征值,并不是纵坐标)(y=0的点标记为绿色,y=1的点标记为红色)
    plt.scatter(x_train[y_train==0, 0], x_train[y_train==0, 1], color = 'g')
    plt.scatter(x_train[y_train==1, 0], x_train[y_train==1, 1], color = 'r')
    plt.scatter(x[0], x[1], color = 'b')
    plt.show()
    
    #KNN的过程
    '''
    distances = []
    for X_train in x_train:
    	d = sqrt(np.sum((X_train - x)**2))
    	distances.append(d)
    print(distances)
    '''
    #计算点x到其余各点的欧拉距离
    distances = [sqrt(np.sum((X_train - x)**2)) for X_train in x_train]
    #print(distances)
    
    #print(np.argsort(distances))
    #argsort是numpy中提供的排序,返回排序后的索引
    nearest = np.argsort(distances)
    k = 6
    #找到距离x最近的前六个点的y值
    topK_y = [y_train[i] for i in nearest[:k]]
    #print(topK_y)
    
    #统计距离x最近的元素对应的种类(0/1)
    print(Counter(topK_y))
    votes = Counter(topK_y)
    #输出该元素最可能的种类
    #print(votes.most_common(1))
    predict_y = votes.most_common(1)[0][0]
    print(predict_y)
    

      

  • 相关阅读:
    Linux OpenSSH后门的加入与防范
    Oracle APEX 4.2安装和配置
    springboot 配置jsp支持
    java 多线程 yield方法的意义
    java多线程状态转换
    Jquery_artDialog对话框弹出
    ThinkPHP框架学习摘要
    js弹窗对象不能通过全局对象移到外部函数中执行
    关于rawurldecode PHP自动解码
    td高度不随内容变化display:block;display:block;display:block;display:block;display:block;
  • 原文地址:https://www.cnblogs.com/mjn1/p/11124675.html
Copyright © 2011-2022 走看看