zoukankan      html  css  js  c++  java
  • 统计学习方法(三)——K近邻法

    /*先把标题给写了、这样就能经常提醒自己*/

    1. k近邻算法

    k临近算法的过程,即对一个新的样本,找到特征空间中与其最近的k个样本,这k个样本多数属于某个类,就把这个新的样本也归为这个类。

    算法 

    输入:训练数据集

    其中为样本的特征向量,为实例的类别,i=1,2,…,N;样本特征向量x(新样本);

    输出:样本x所属的类y。

    (1)根据给定的距离度量,在训练集T中找出与x最相邻的k个点,涵盖这k个点的邻域记作

    (2)在中根据分类决策规则(如多数表决)决定x的类别y:

                                                                        (1)

    式中I为指示函数,即当时I为1,否则为0。

    由这个简单的算法过程可以看出来,距离的选择、以及k的选择都是很重要的,这恰好对应的三个要素中的两个,另一个为分类决策规则,一般来说是多数表决法。

     

    2. k近邻模型

    k近邻算法使用的模型实际上对应于特征空间的划分,模型由三个基本要素——距离度量、k值的选择和分类决策规则决定。

    距离度量

          特征空间中俩个实例的距离是俩个实例点相似程度的反映,k近邻中一般使用欧氏距离,本文中主要只介绍这一种。

    设特征空间维实数向量空间,,距离定义为

     

    当p=2时,称为欧氏距离(Euclidean distance).

    ==================在此吐槽一下,博客园的图片插入好折腾人啊,已经搞出肩周炎了,明天再继续码第二要素了 2014-6-30========================

    举个粟子,已知,,则  的欧氏距离为 ,挺容易理解的吧!

    K值的选择

    首先说明一下K值的选择对最终的结果有很大的影响!!!

          如果选择的k过小,则预测的结果对近邻的实例点非常敏感,如果近邻刚好是噪声,则预测就会出错,例如k=1,很难保证最近的一个点就是正确的预测,亦即容易发生过拟合!如果选择的k过大,则会忽略掉训练实例中的大量有用信息,例如k=N,那么无论输入实例是什么最终的结果都将是训练实例中最多的类。

          关于分类决策规则这里就不再赘述,正常情况下直接采用多数表决即可,如果觉得结果不满意的话,可以加入各个类的先验概率进去融合!

     

    3. K近邻的实现

    该小节书本中用到了KD树,通过构造平衡KD树来方便快速查找训练数据中离测试实例最近的点,不过构造这颗树本身是一个比较繁琐的过程(其实是本人代码能力实在太菜了,真的觉得把KD树写下来需要花太多时间了,而且KD树中每增加一个新数据又要进行节点插入操作,实在不方便,直接放弃),所以直接用最土豪的方法,时间复杂度差就差了,咱有的是CPU!!!

    在这里直接套用书中例子,不过实现上就用其它算法了。稍等,我勒个去!书中的例子只是用于构造KD树的,李航兄你不厚道啊,说好的K近邻怎么变成这样了,不能直接引用书中例子了,自己再编一个得了。

    例子:训练数据集中,正样本点有,负样本点有,现要求判断实例属于哪个类别,如下图所示:

          

    假设取K=3,则距离最近的3个点为,按照多数表决规则可得出应该属于正类。

      为了表示咱们不是拍脑袋给出的结果,下面给出具体的代码实现

    package org.juefan.knn;
    
    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.Comparator;
    import java.util.HashMap;
    import java.util.Map;
    
    import org.juefan.basic.FileIO;
    import org.juefan.data.Data;
    
    public class SimpleKnn {
        
        public static final int K = 3;        
        public static int P = 2;        //距离函数的选择,P=2即欧氏距离
        
        public class LabelDistance{
            public double distance = 0;
            public int label;        
            public LabelDistance(double d, int l){
                distance = d;
                label = l;
            }
        }
        
        public sort compare = new sort();
        public class sort implements Comparator<LabelDistance> {
            public int compare(LabelDistance arg0, LabelDistance arg1) {
                return arg0.distance < arg1.distance ? -1 : 1;        //JDK1.7的新特性,返回值必须是一对正负数
            }
        }
        
        /**
         * 俩个实例间的距离函数
         * @param a
         * @param b
         * @return 返回距离值,如果俩个实例的维度不一致则返回一个极大值
         */
        public double getLdistance(Data a, Data b){
            if(a.x.size() != b.x.size())
                return Double.MAX_VALUE;
            double inner = 0;
            for(int i = 0; i < P; i++){
                inner += Math.pow((a.x.get(i) - b.x.get(i)) , P);
            }
            return Math.pow(inner, (double)1/P);    
        }
        
        /**
         * 计算实例与训练集的距离并返回最终判断结果
         * @param d 待判断实例
         * @param tran 训练集
         * @return 实例的判断结果
         */
        public int getLabelvalue(Data d, ArrayList<Data> tran){
            ArrayList<LabelDistance> labelDistances= new ArrayList<>();
            Map<Integer, Integer> map = new HashMap<>();
            int label = 0;
            int count = 0;
            for(Data data: tran){
                labelDistances.add(new LabelDistance(getLdistance(d, data), data.y));
            }
            Collections.sort(labelDistances, compare);
            for(int i = 0; i < K & i < labelDistances.size(); i++){
                //System.out.println(labelDistances.get(i).distance + "	" + labelDistances.get(i).label);
                int tmplabel = labelDistances.get(i).label;
                if(map.containsKey(tmplabel)){
                    map.put(tmplabel, map.get(tmplabel) + 1);
                }else {
                    map.put(tmplabel, 1);
                }
            }
            for(int key: map.keySet()){
                if(map.get(key) > count){
                    count = map.get(key);
                    label = key;
                }
            }
            return label;    
        }
        
        public static void main(String[] args) {
            SimpleKnn knn = new SimpleKnn();
            ArrayList<Data> datas = new ArrayList<>();
            FileIO fileIO = new FileIO();
            fileIO.setFileName(".//file//knn.txt");
            fileIO.FileRead();
            for(String data: fileIO.fileList){
                datas.add(new Data(data));
            }
            Data data = new Data();
            data.x.add(2); data.x.add(1);
            System.out.println(knn.getLabelvalue(data, datas));
        }
    }

    对代码有兴趣的可以上本人的GitHub查看:https://github.com/JueFan/StatisticsLearningMethod/

  • 相关阅读:
    《Ubuntu标准教程》学习总结
    Ubuntu下安装VirtualBox并为其添加USB支持
    Eclipse下配置TinyOS开发环境
    Ubuntu下的网络服务
    Ubuntu12.04添加环境变量
    Ubuntu12.04下搭建Java环境
    poj 1066 线段相交
    poj 2653 (线段相交判断)
    poj 2398 (叉积+二分)
    hdu 4883 思维题
  • 原文地址:https://www.cnblogs.com/juefan/p/3807713.html
Copyright © 2011-2022 走看看