  • Running a cross-validation experiment with Weka

    Generating cross-validation folds (Java approach)

    Reference:

    http://weka.wikispaces.com/Generating+cross-validation+folds+%28Java+approach%29


    This article describes how to generate train/test splits for cross-validation using the Weka API directly. 

    The following variables are given:

     Instances data = ...;   // the full dataset we want to create train/test sets from
     int seed = ...;         // the seed for randomizing the data
     int folds = ...;        // the number of folds to generate, >= 2



     Randomize the data

    First, randomize your data:

     Random rand = new Random(seed);             // create seeded number generator
     Instances randData = new Instances(data);   // create copy of original data
     randData.randomize(rand);                   // randomize data with number generator
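    To see why a fixed seed makes the randomization reproducible, here is a minimal sketch in plain Java. It uses only java.util and does not call Weka; the class and method names (SeededShuffleSketch, shuffledIndices) are made up for illustration.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical helper, not part of Weka: shuffles row indices instead of instances.
class SeededShuffleSketch {

    // Returns the indices 0..n-1 shuffled by a generator seeded with `seed`.
    static List<Integer> shuffledIndices(int n, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, new Random(seed));
        return indices;
    }

    public static void main(String[] args) {
        // The same seed always yields the same order, so a run can be repeated exactly.
        System.out.println(shuffledIndices(10, 1).equals(shuffledIndices(10, 1))); // prints "true"
    }
}
```

    The same reasoning applies to the seeded Random passed to randomize above: rerunning with the same seed reproduces the same order, and therefore the same folds.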

    If your data has a nominal class and you want to perform stratified cross-validation, so that each fold keeps roughly the class distribution of the full dataset:

     randData.stratify(folds);
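    The idea behind stratification can be illustrated without Weka. The sketch below is an assumption-laden illustration, not Weka's implementation: it groups instance indices by class label and deals them out round-robin, so every fold receives a near-equal share of each class. It also shows why the class has to be nominal: the grouping needs discrete labels. StratifySketch and assignFolds are hypothetical names.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical illustration of stratified fold assignment (not Weka code).
class StratifySketch {

    // Returns, for every instance index, the fold it is assigned to.
    static int[] assignFolds(String[] labels, int folds) {
        // Group the instance indices by their class label.
        Map<String, List<Integer>> byClass = new LinkedHashMap<>();
        for (int i = 0; i < labels.length; i++) {
            byClass.computeIfAbsent(labels[i], k -> new ArrayList<>()).add(i);
        }
        int[] foldOf = new int[labels.length];
        int next = 0; // running counter, so small classes do not all start at fold 0
        for (List<Integer> members : byClass.values()) {
            for (int index : members) {
                foldOf[index] = next % folds;
                next++;
            }
        }
        return foldOf;
    }

    public static void main(String[] args) {
        String[] labels = {"red", "white", "red", "red", "white",
                           "red", "white", "red", "white", "red"};
        int[] foldOf = assignFolds(labels, 2);
        for (int i = 0; i < labels.length; i++) {
            System.out.println(labels[i] + " -> fold " + foldOf[i]);
        }
    }
}
```

    With 6 "red" and 4 "white" labels and 2 folds, each fold ends up with 3 red and 2 white instances, mirroring the 60/40 split of the whole set.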



     Generate the folds

     Single run

    Next, create the training and test sets for each fold:

     for (int n = 0; n < folds; n++) {
       Instances train = randData.trainCV(folds, n);
       Instances test = randData.testCV(folds, n);

       // further processing, classification, etc.
       ...
     }

    Note:

    • the above code is used by the weka.filters.supervised.instance.StratifiedRemoveFolds filter
    • the weka.classifiers.Evaluation class and the Explorer/Experimenter use the following method to obtain the train set, which additionally shuffles the training data with the given random number generator:

     Instances train = randData.trainCV(folds, n, rand);
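    The fold loop above relies on trainCV and testCV to cut the randomized data into a test block and the remaining training rows. As a rough illustration of one common way to compute such contiguous blocks (a sketch in plain Java, not taken from Weka's source), the first numInstances % folds blocks get one extra row. FoldBoundsSketch, testRange and trainSize are hypothetical names.

```java
// Hypothetical illustration of contiguous fold boundaries (not Weka code).
class FoldBoundsSketch {

    // Returns {firstIndex, size} of test fold n when numInstances rows
    // are cut into `folds` contiguous blocks.
    static int[] testRange(int numInstances, int folds, int n) {
        int base = numInstances / folds;   // minimum rows per fold
        int extra = numInstances % folds;  // folds 0..extra-1 get one more row
        int size = base + (n < extra ? 1 : 0);
        int first = n * base + Math.min(n, extra);
        return new int[] {first, size};
    }

    // The training set for fold n is everything outside the test block.
    static int trainSize(int numInstances, int folds, int n) {
        return numInstances - testRange(numInstances, folds, n)[1];
    }

    public static void main(String[] args) {
        for (int n = 0; n < 3; n++) {
            int[] r = testRange(10, 3, n);
            System.out.println("fold " + n + ": test rows " + r[0] + ".." + (r[0] + r[1] - 1)
                    + ", train size " + trainSize(10, 3, n));
        }
        // prints:
        // fold 0: test rows 0..3, train size 6
        // fold 1: test rows 4..6, train size 7
        // fold 2: test rows 7..9, train size 7
    }
}
```

    Every row falls into exactly one test block, which is what makes the per-fold evaluations add up to one pass over the whole dataset.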



     Multiple runs

    The example above performs only a single run of cross-validation. To perform 10 runs of 10-fold cross-validation, use the following loop:

     Instances data = ...;  // our dataset again, obtained from somewhere
     int runs = 10;
     for (int i = 0; i < runs; i++) {
       seed = i+1;  // every run gets a new, but defined seed value

       // see: randomize the data
       ...

       // see: generate the folds
       ...
     }
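    The structure of the repeated loop can be checked with a small stand-alone sketch. It is plain Java working on row indices rather than Weka instances, and the names RepeatedCvSketch and testCounts are hypothetical. Each run reshuffles with its own seed (i + 1, as above), cuts the shuffled order into contiguous folds, and counts how often each row is used for testing.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical illustration of repeated k-fold cross-validation on row indices.
class RepeatedCvSketch {

    // Returns how many times each row index appeared in a test fold over all runs.
    static int[] testCounts(int numInstances, int folds, int runs) {
        int[] counts = new int[numInstances];
        for (int i = 0; i < runs; i++) {
            long seed = i + 1; // every run gets a new, but defined seed value

            // randomize the data (here: a list of row indices)
            List<Integer> order = new ArrayList<>();
            for (int k = 0; k < numInstances; k++) {
                order.add(k);
            }
            Collections.shuffle(order, new Random(seed));

            // generate the folds as contiguous blocks of the shuffled order
            int base = numInstances / folds;
            int extra = numInstances % folds;
            for (int n = 0; n < folds; n++) {
                int first = n * base + Math.min(n, extra);
                int size = base + (n < extra ? 1 : 0);
                for (int p = first; p < first + size; p++) {
                    counts[order.get(p)]++; // row is in the test set of fold n
                }
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        int[] counts = testCounts(23, 10, 10);
        System.out.println("row 0 was tested " + counts[0] + " times"); // prints "row 0 was tested 10 times"
    }
}
```

    Each row is tested exactly once per run, so after 10 runs every count is 10; only the grouping of rows into folds changes from run to run.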

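    A small experiment:

    We continue classifying the red and white wines from the previous section. The classifier is unchanged; only the repeated runs have been added.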

    package assignment2;

    import weka.core.Instances;
    import weka.core.converters.ConverterUtils.DataSource;
    import weka.core.Utils;
    import weka.classifiers.Classifier;
    import weka.classifiers.Evaluation;
    import weka.classifiers.trees.J48;
    import weka.filters.Filter;
    import weka.filters.unsupervised.attribute.Remove;

    import java.io.FileReader;
    import java.util.Random;

    public class cv_rw {

        // Loads an ARFF file and removes its last attribute with the Remove filter.
        public static Instances getFileInstances(String filename) throws Exception {
            FileReader frData = new FileReader(filename);
            Instances data = new Instances(frData);
            int length = data.numAttributes();
            String[] options = new String[2];
            options[0] = "-R";                       // range of attributes to remove
            options[1] = Integer.toString(length);   // 1-based index of the last attribute
            Remove remove = new Remove();
            remove.setOptions(options);
            remove.setInputFormat(data);
            Instances newData = Filter.useFilter(data, remove);
            return newData;
        }

        public static void main(String[] args) throws Exception {
            // load data and set class index (last remaining attribute)
            Instances data = getFileInstances("D://Weka_tutorial//WineQuality//RedWhiteWine.arff");
    //      System.out.println(instances);
            data.setClassIndex(data.numAttributes() - 1);

            // classifier
    //      String[] tmpOptions;
    //      String classname;
    //      tmpOptions     = Utils.splitOptions(Utils.getOption("W", args));
    //      classname      = tmpOptions[0];
    //      tmpOptions[0]  = "";
    //      Classifier cls = (Classifier) Utils.forName(Classifier.class, classname, tmpOptions);
    //
    //      // other options
    //      int runs  = Integer.parseInt(Utils.getOption("r", args)); // number of repeated runs
    //      int folds = Integer.parseInt(Utils.getOption("x", args));
            int runs = 1;
            int folds = 10;
            J48 j48 = new J48();
    //      j48.buildClassifier(instances);

            // perform cross-validation
            for (int i = 0; i < runs; i++) {
              // randomize data
              int seed = i + 1;
              Random rand = new Random(seed);
              Instances randData = new Instances(data);
              randData.randomize(rand);
    //        (I did not understand what the following check does; explanations are very welcome.)
    //        if (randData.classAttribute().isNominal())
    //          randData.stratify(folds);

              Evaluation eval = new Evaluation(randData);
              for (int n = 0; n < folds; n++) {
                Instances train = randData.trainCV(folds, n);
                Instances test = randData.testCV(folds, n);
                // the above code is used by the StratifiedRemoveFolds filter, the
                // code below by the Explorer/Experimenter:
                // Instances train = randData.trainCV(folds, n, rand);

                // build and evaluate classifier on a fresh copy for every fold
                // (in Weka 3.7 and later, makeCopy is on weka.classifiers.AbstractClassifier)
                Classifier j48Copy = Classifier.makeCopy(j48);
                j48Copy.buildClassifier(train);
                eval.evaluateModel(j48Copy, test);
              }

              // output evaluation
              System.out.println();
              System.out.println("=== Setup run " + (i+1) + " ===");
              System.out.println("Classifier: " + j48.getClass().getName());
              System.out.println("Dataset: " + data.relationName());
              System.out.println("Folds: " + folds);
              System.out.println("Seed: " + seed);
              System.out.println();
              System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation run " + (i+1) + "===", false));
            }
        }
    }

    Running the program produces the following results:

     

    === Setup run 1 ===
    Classifier: weka.classifiers.trees.J48
    Dataset: RedWhiteWine-weka.filters.unsupervised.instance.Randomize-S42-weka.filters.unsupervised.instance.Randomize-S42-weka.filters.unsupervised.attribute.Remove-R13
    Folds: 10
    Seed: 1

    === 10-fold Cross-validation run 1===
    Correctly Classified Instances        6415               98.7379 %
    Incorrectly Classified Instances        82                1.2621 %
    Kappa statistic                          0.9658
    Mean absolute error                      0.0159
    Root mean squared error                  0.1109
    Relative absolute error                  4.2898 %
    Root relative squared error             25.7448 %
    Total Number of Instances             6497
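
    As a quick sanity check on the summary, the two percentages follow directly from the instance counts (6415 + 82 = 6497). The snippet below is a small plain-Java sketch of that arithmetic; AccuracyCheck and percent are hypothetical names and the code does not use Weka.

```java
import java.util.Locale;

// Hypothetical helper that recomputes the percentages from the counts above.
class AccuracyCheck {

    // Formats part/total as a percentage with four decimal places.
    static String percent(int part, int total) {
        return String.format(Locale.ROOT, "%.4f", 100.0 * part / total);
    }

    public static void main(String[] args) {
        System.out.println(percent(6415, 6497)); // prints "98.7379"
        System.out.println(percent(82, 6497));   // prints "1.2621"
    }
}
```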

  • Original article: https://www.cnblogs.com/7899-89/p/3667330.html