zoukankan      html  css  js  c++  java
  • KNN算法基本实例

      KNN算法是机器学习领域中一个最基本的经典算法。它属于无监督学习领域的算法并且在模式识别,数据挖掘和特征提取领域有着广泛的应用。

    给定一些预处理数据,通过一个属性把这些分类坐标分成不同的组。这就是KNN的思路。

      下面,举个例子来说明一下。图中的数据点包含两个特征:

      现在,给出数据点的另外一个节点,通过分析训练节点来把这些节点分类。没有分来的及诶但我们标记为白色,如下所示:

      直观来讲,如果我们把那些节点花道一个图片上,我们可能就能确定一些特征,或组。现在,给一个没有分类的点,我们可以通过观察它距离那个组位置最近来确定它属于哪个组。意思就是,假如一个点距离红色的组最近,我们就可以把这个点归为红色的组。简而言之,我们可以把第一个点(2.5,7)归类为绿色,把第二个点(5.5,4.5)归类为红色。

      算法流程:

      假设m是训练样本的数量,p是一个未知的节点。

      1 把所有训练的样本放到也数组arr[]中。这个意思就是这个数组中每个元素就可以使用元组(x,y)表示。

      2 伪码

    for i=0 to m:
      Calculate Euclidean distance d(arr[i], p).

      3 标记设置S为K的最小距离。这里每个距离都和一个已经分类的数据点相关。

      4 返回在S之间的大多数标签。

      实际程序C代码:

     

    // C++ program to find groups of unknown
    // Points using K nearest neighbour algorithm.
    #include <bits/stdc++.h>
    using namespace std;
     
    struct Point
    {
        int val;     // Group of point
        double x, y;     // Co-ordinate of point
        double distance; // Distance from test point
    };
     
    // Used to sort an array of points by increasing
    // order of distance
    bool comparison(Point a, Point b)
    {
        return (a.distance < b.distance);
    }
     
    // This function finds classification of point p using
    // k nearest neighbour algorithm. It assumes only two
    // groups and returns 0 if p belongs to group 0, else
    // 1 (belongs to group 1).
    int classifyAPoint(Point arr[], int n, int k, Point p)
    {
        // Fill distances of all points from p
        for (int i = 0; i < n; i++)
            arr[i].distance =
                sqrt((arr[i].x - p.x) * (arr[i].x - p.x) +
                     (arr[i].y - p.y) * (arr[i].y - p.y));
     
        // Sort the Points by distance from p
        sort(arr, arr+n, comparison);
     
        // Now consider the first k elements and only
        // two groups
        int freq1 = 0;     // Frequency of group 0
        int freq2 = 0;     // Frequency of group 1
        for (int i = 0; i < k; i++)
        {
            if (arr[i].val == 0)
                freq1++;
            else if (arr[i].val == 1)
                freq2++;
        }
     
        return (freq1 > freq2 ? 0 : 1);
    }
     
    // Driver code
    int main()
    {
        int n = 17; // Number of data points
        Point arr[n];
     
        arr[0].x = 1;
        arr[0].y = 12;
        arr[0].val = 0;
     
        arr[1].x = 2;
        arr[1].y = 5;
        arr[1].val = 0;
     
        arr[2].x = 5;
        arr[2].y = 3;
        arr[2].val = 1;
     
        arr[3].x = 3;
        arr[3].y = 2;
        arr[3].val = 1;
     
        arr[4].x = 3;
        arr[4].y = 6;
        arr[4].val = 0;
     
        arr[5].x = 1.5;
        arr[5].y = 9;
        arr[5].val = 1;
     
        arr[6].x = 7;
        arr[6].y = 2;
        arr[6].val = 1;
     
        arr[7].x = 6;
        arr[7].y = 1;
        arr[7].val = 1;
     
        arr[8].x = 3.8;
        arr[8].y = 3;
        arr[8].val = 1;
     
        arr[9].x = 3;
        arr[9].y = 10;
        arr[9].val = 0;
     
        arr[10].x = 5.6;
        arr[10].y = 4;
        arr[10].val = 1;
     
        arr[11].x = 4;
        arr[11].y = 2;
        arr[11].val = 1;
     
        arr[12].x = 3.5;
        arr[12].y = 8;
        arr[12].val = 0;
     
        arr[13].x = 2;
        arr[13].y = 11;
        arr[13].val = 0;
     
        arr[14].x = 2;
        arr[14].y = 5;
        arr[14].val = 1;
     
        arr[15].x = 2;
        arr[15].y = 9;
        arr[15].val = 0;
     
        arr[16].x = 1;
        arr[16].y = 7;
        arr[16].val = 0;
     
        /*Testing Point*/
        Point p;
        p.x = 2.5;
        p.y = 7;
     
        // Parameter to decide groupr of the testing point
        int k = 3;
        printf ("The value classified to unknown point"
                " is %d.
    ", classifyAPoint(arr, n, k, p));
        return 0;
    }
    View Code

      实际程序python代码:

      

     1 # Python3 program to find groups of unknown
     2 # Points using K nearest neighbour algorithm.
     3  
     4 import math
     5  
     6 def classifyAPoint(points,p,k=3):
     7     '''
     8      This function finds classification of p using
     9      k nearest neighbour algorithm. It assumes only two
    10      groups and returns 0 if p belongs to group 0, else
    11       1 (belongs to group 1).
    12  
    13       Parameters - 
    14           points : Dictionary of training points having two keys - 0 and 1
    15                    Each key have a list of training data points belong to that 
    16  
    17           p : A touple ,test data point of form (x,y)
    18  
    19           k : number of nearest neighbour to consider, default is 3 
    20     '''
    21  
    22     distance=[]
    23     for group in points:
    24         for feature in points[group]:
    25  
    26             #calculate the euclidean distance of p from training points 
    27             euclidean_distance = math.sqrt((feature[0]-p[0])**2 +(feature[1]-p[1])**2)
    28  
    29             # Add a touple of form (distance,group) in the distance list
    30             distance.append((euclidean_distance,group))
    31  
    32     # sort the distance list in ascending order
    33     # and select first k distances
    34     distance = sorted(distance)[:k]
    35  
    36     freq1 = 0 #frequency of group 0
    37     freq2 = 0 #frequency og group 1
    38  
    39     for d in distance:
    40         if d[1] == 0:
    41             freq1 += 1
    42         elif d[1] == 1:
    43             freq2 += 1
    44  
    45     return 0 if freq1>freq2 else 1
    46  
    47 # driver function
    48 def main():
    49  
    50     # Dictionary of training points having two keys - 0 and 1
    51     # key 0 have points belong to class 0
    52     # key 1 have points belong to class 1
    53  
    54     points = {0:[(1,12),(2,5),(3,6),(3,10),(3.5,8),(2,11),(2,9),(1,7)],
    55               1:[(5,3),(3,2),(1.5,9),(7,2),(6,1),(3.8,1),(5.6,4),(4,2),(2,5)]}
    56  
    57     # testing point p(x,y)
    58     p = (2.5,7)
    59  
    60     # Number of neighbours 
    61     k = 3
    62  
    63     print("The value classified to unknown point is: {}".
    64           format(classifyAPoint(points,p,k)))
    65  
    66 if __name__ == '__main__':
    67     main()
    68      
    69 # This code is contributed by Atul Kumar (www.fb.com/atul.kr.007)
    View Code

      

      

  • 相关阅读:
    05流程图和流程定义的操作
    04启动流程实例,任务的查询与完成
    03流程图的绘制与部署
    02数据库表的初始化方式
    01环境安装
    JavaScript基础和JavaScript内置对象:
    用手机、pid作为win电脑扩展屏
    H5新增特性之语义化标签
    盒模型
    CSS定位总结--static、relative、absolute、fixed
  • 原文地址:https://www.cnblogs.com/dylancao/p/9150342.html
Copyright © 2011-2022 走看看