zoukankan      html  css  js  c++  java
  • 机器学习:weka中Evaluation类源码解析及输出AUC及交叉验证介绍

      在机器学习分类结果的评估中,ROC曲线下的面积AOC是一个非常重要的指标。下面是调用weka类,输出AOC的源码:

    try {
    // 1.读入数据集
    
                    Instances data = new Instances(
                                          new BufferedReader(
                                            new FileReader("E:\Develop/Weka-3-6/data/contact-lenses.arff")));
    
                    data.setClassIndex(data.numAttributes() - 1);
    
    // 2.训练分类器并用十字交叉验证法来获得Evaluation对象
    // 注意这里的方法与我们在上几节中使用的验证法是不同。
                    Classifier cl = new NaiveBayes();
                    Evaluation eval = new Evaluation(data);
                    eval.crossValidateModel(cl, data, 10, new Random(1));
    
             
    // 3.生成用于得到ROC曲面和AUC值的Instances对象
           System.out.println(eval.toClassDetailsString());
                System.out.println(eval.toSummaryString());
                System.out.println(eval.toMatrixString()); }
    catch (Exception e) { e.printStackTrace(); }

      接着说一下交叉验证;

      如果没有分开训练集和测试集,可以使用Cross Validation方法,Evaluation中crossValidateModel方法的四个参数分别为,第一个是分类器,第二个是在某个数据集上评价的数据集,第三个参数是交叉检验的次数(10是比较常见的),第四个是一个随机数对象。

      注意:使用crossValidateModel时,分类器不需要先训练,否则buildClassifier方法会初始化分类器,交叉验证的配置结果就没有用了。

      类crossValidateModel的源码如下:

     public void crossValidateModel(Classifier classifier, Instances data,
        int numFolds, Random random, Object... forPredictionsPrinting)
        throws Exception {
    
        // Make a copy of the data we can reorder
        data = new Instances(data);
        data.randomize(random);
        if (data.classAttribute().isNominal()) {
          data.stratify(numFolds);
        }
    
        // We assume that the first element is a StringBuffer, the second a Range
        // (attributes
        // to output) and the third a Boolean (whether or not to output a
        // distribution instead
        // of just a classification)
        if (forPredictionsPrinting.length > 0) {
          // print the header first
          StringBuffer buff = (StringBuffer) forPredictionsPrinting[0];
          Range attsToOutput = (Range) forPredictionsPrinting[1];
          boolean printDist = ((Boolean) forPredictionsPrinting[2]).booleanValue();
          printClassificationsHeader(data, attsToOutput, printDist, buff);
        }
    
        // Do the folds
        for (int i = 0; i < numFolds; i++) {
          Instances train = data.trainCV(numFolds, i, random);
          setPriors(train);
          Classifier copiedClassifier = Classifier.makeCopy(classifier);
          copiedClassifier.buildClassifier(train);
          Instances test = data.testCV(numFolds, i);
          evaluateModel(copiedClassifier, test, forPredictionsPrinting);
        }
        m_NumFolds = numFolds;
      }

    输出结果截图:

    更新中。。。

    libsvm 下载地址 https://github.com/cjlin1/libsvm

        github地址   https://github.com/cjlin1/libsvm

  • 相关阅读:
    Web Components 是什么
    HAL_RTC_MspInit Msp指代什么?
    C 枚举 相同的值
    EntityFramework Core并发深挖详解,一纸长文,你准备好看完了吗?
    ASP.NET Core MVC之ViewComponents(视图组件)知多少?
    .NET Core 1.1日期解析无APi、SQL Server数据转换JSON
    SQL Server-字字珠玑,一纸详文,完全理解SERIALIZABLE最高隔离级别(基础系列收尾篇)
    SQL Server-聚焦NOLOCK、UPDLOCK、HOLDLOCK、READPAST你弄懂多少?(三十四)
    SQL Server-聚焦深入理解死锁以及避免死锁建议(三十三)
    ASP.NET Core MVC上传、导入、导出知多少
  • 原文地址:https://www.cnblogs.com/rongyux/p/5386120.html
Copyright © 2011-2022 走看看