zoukankan      html  css  js  c++  java
  • 机器学习:IB1算法的weka源码详细解析(1NN)

      机器学习的1NN最近邻算法,在weka里叫IB1,是因为Instance Base  1 ,也就是只基于一个最近邻的实例的惰性学习算法。

      下面总结一下,weka中对IB1源码的学习总结。

      首先需要把 weka-src.jar 引入编译路径,否则无法跟踪源码。

      1)读取data数据,完成 IB1 分类器的调用,结果预测评估。为了后面的跟踪。

    try {
                File file = new File("F:\tools/lib/data/contact-lenses.arff");
    
                ArffLoader loader = new ArffLoader();
                loader.setFile(file);
                ins = loader.getDataSet();
    
                // 在使用样本之前一定要首先设置instances的classIndex,否则在使用instances对象是会抛出异常
                ins.setClassIndex(ins.numAttributes() - 1);
                
                cfs = new IB1();
                cfs.buildClassifier(ins);
                            
                Instance testInst;
                Evaluation testingEvaluation = new Evaluation(ins);
                int length = ins.numInstances();
                for (int i = 0; i < length; i++) {
                    testInst = ins.instance(i);
                    // 通过这个方法来用每个测试样本测试分类器的效果
                    double predictValue = cfs.classifyInstance(testInst);
                    
                    System.out.println(testInst.classValue()+"--"+predictValue);
                }
    
               // System.out.println("分类器的正确率:" + (1 - testingEvaluation.errorRate()));
    
            } catch (Exception e) {
                e.printStackTrace();
            }

    2)ctrl 点击buildClassifier,进一步跟踪buildClassifier方法的源码,在IB1的类中重写了这个抽象方法,源码为:

    public void buildClassifier(Instances instances) throws Exception {
        
        // can classifier handle the data?
        getCapabilities().testWithFail(instances);
    
        // remove instances with missing class
        instances = new Instances(instances);
        instances.deleteWithMissingClass();
        
        m_Train = new Instances(instances, 0, instances.numInstances());
    
        m_MinArray = new double [m_Train.numAttributes()];
        m_MaxArray = new double [m_Train.numAttributes()];
        for (int i = 0; i < m_Train.numAttributes(); i++) {
          m_MinArray[i] = m_MaxArray[i] = Double.NaN;
        }
        Enumeration enu = m_Train.enumerateInstances();
        while (enu.hasMoreElements()) {
          updateMinMax((Instance) enu.nextElement());
        }
      }

      (1)if是判断,IB1分类器不能处理属性是字符串和类别是数值型的样本;

      (2)if是判断,删除没有类标签的样本;

      (3)m_MinArray 和 m_MaxArray 分别保存最小和最大值,并且初始化double数组【样本个数】;

      (4)遍历所有的训练样本实例,求最小和最大值;继续跟踪updateMinMax方法;

      3)IB1类的updateMinMax方法的源码如下:

      private void updateMinMax(Instance instance) {
        
        for (int j = 0;j < m_Train.numAttributes(); j++) {
          if ((m_Train.attribute(j).isNumeric()) && (!instance.isMissing(j))) {
        if (Double.isNaN(m_MinArray[j])) {
          m_MinArray[j] = instance.value(j);
          m_MaxArray[j] = instance.value(j);
        } else {
          if (instance.value(j) < m_MinArray[j]) {
            m_MinArray[j] = instance.value(j);
          } else {
            if (instance.value(j) > m_MaxArray[j]) {
              m_MaxArray[j] = instance.value(j);
            }
          }
        }
          }
        }
      }

      (1)过滤掉属性不是数值型和缺失标签的实例;

      (2)若是isNaN,is not a number,是数值型的话,循环遍历样本的每一个属性,求出最大最小值;

      到此为止,训练了IB1模型(有人可能会问lazy的算法难道不是不需要训练模型吗?我认为build分类器是为了初始化 m_Train和求所有实例的每个属性的最大最小值,为了下一步求distance做准备)

    下面介绍下预测源码:

      

      4)跟踪classifyInstance方法,源码如下:

     public double classifyInstance(Instance instance) throws Exception {
        
        if (m_Train.numInstances() == 0) {
          throw new Exception("No training instances!");
        }
    
        double distance, minDistance = Double.MAX_VALUE, classValue = 0;
        updateMinMax(instance);
        Enumeration enu = m_Train.enumerateInstances();
        while (enu.hasMoreElements()) {
          Instance trainInstance = (Instance) enu.nextElement();
          if (!trainInstance.classIsMissing()) {
        distance = distance(instance, trainInstance);
        if (distance < minDistance) {
          minDistance = distance;
          classValue = trainInstance.classValue();
        }
          }
        }
    
        return classValue;
      }

      (1)调用方法updateMinMax更新了加入测试实例后的最大最小值;

      (2)计算测试实例到每一个训练实例的距离,distance方法,并且保存距离最小的实例minDistance;

      5)跟踪classifyInstance方法,源码如下:

     private double distance(Instance first, Instance second) {
        
        double diff, distance = 0;
    
        for(int i = 0; i < m_Train.numAttributes(); i++) { 
          if (i == m_Train.classIndex()) {
        continue;
          }
          if (m_Train.attribute(i).isNominal()) {
    
        // If attribute is nominal
        if (first.isMissing(i) || second.isMissing(i) ||
            ((int)first.value(i) != (int)second.value(i))) {
          distance += 1;
        }
          } else {
        
        // If attribute is numeric
        if (first.isMissing(i) || second.isMissing(i)){
          if (first.isMissing(i) && second.isMissing(i)) {
            diff = 1;
          } else {
            if (second.isMissing(i)) {
              diff = norm(first.value(i), i);
            } else {
              diff = norm(second.value(i), i);
            }
            if (diff < 0.5) {
              diff = 1.0 - diff;
            }
          }
        } else {
          diff = norm(first.value(i), i) - norm(second.value(i), i);
        }
        distance += diff * diff;
          }
        }
        
        return distance;
      }

      对每一个属性遍历,计算数值属性距离的平方和,norm方法为规范化距离公式,为【0,1】的实数  

      6)跟踪norm规范化方法,源码如下:

      private double norm(double x,int i) {
    
        if (Double.isNaN(m_MinArray[i])
        || Utils.eq(m_MaxArray[i], m_MinArray[i])) {
          return 0;
        } else {
          return (x - m_MinArray[i]) / (m_MaxArray[i] - m_MinArray[i]);
        }
      }

      规范化距离:(x - m_MinArray[i]) / (m_MaxArray[i] - m_MinArray[i]);

      

     具体的算法伪代码,请查找最近邻分类器的论文,我就不贴出来了。

  • 相关阅读:
    php 建立类POST/GET 的HTTP请求
    上传文件
    golang精选100题带答案
    go面试
    golang反射
    go语言中type的几种使用
    写个版本迭代的方法 例如1.0.9 迭代为1.1.0 到10自动往前进1
    压缩文件和解压文件
    go语言中的文件创建,写入,读取,删除
    go面试题
  • 原文地址:https://www.cnblogs.com/rongyux/p/5371159.html
Copyright © 2011-2022 走看看