zoukankan      html  css  js  c++  java
  • 【weka】分类,cross-validation,数据

    一、分类classifier

      如何利用weka里的类对数据集进行分类,要对数据集进行分类,第一步要指定数据集中哪一列做为类别,如果这一步忘记了(事实上经常会忘记)会出现“Class index is negative (not set)!”这个错误,设置某一列为类别用Instances类的成员方法setClassIndex,要设置最后一列为类别则可以用Instances类的numAttributes()成员方法得到属性的个数再减1。

      然后选择分类器,比较常用的分类器有J48,NaiveBayes,SMO(LibSVM有Java版的,可以在weka中使用,但要设置路径),训练分类器使用J48的buildClassifier(注意J48还有别的分类器它们都继承自Classifier类,使用方法都差不多),分类数据用J48类中的classifyInstance方法,例中使用的数据集为contact-lenses.arff,分类结果为2.0,结果为2.0的原因是:首先用文本编辑器打开数据集,有一行为@attribute contact-lenses {soft, hard, none},而第一个样本为young, myope, no, reduced, none,最后一列为类别,也就是contact-lences为类别,第一个样本的类别为none,在属性说明中none为第二个所以为2.0(从0开始数)。

    二、评估Evaluation

      Evaluation类,这次只讲一下最简单的用法,首先初始化一个Evaluation对象,Evaluation类没有无参的构造函数,一般用Instances对象作为构造函数的参数。

           如果没有分开训练集和测试集,可以使用Cross Validation方法EvaluationcrossValidateModel方法的四个参数分别为,第一个是分类器,第二个是在某个数据集上评价的数据集,第三个参数是交叉检验的次数(10是比较常见的),第四个是一个随机数对象。

           如果有训练集和测试集,可以使用Evaluation 类中的evaluateModel方法,方法中的参数为:第一个为一个训练过的分类器,第二个参数是在某个数据集上评价的数据集。例中我为了简单用训练集再次做为测试集,希望大家不会糊涂。

           提醒大家一下,使用crossValidateModel时,分类器不需要先训练,这其实也应该是常识了。

           Evaluation中提供了多种输出方法,大家如果用过weka软件,会发现方法输出结果与软件中某个显示结果的是对应的。例中的三个方法toClassDetailsStringtoSummaryStringtoMatrixString比较常用。

    三、特征选择AttributeSelection

      用AttributeSelection进行特征选择,它需要设置3个方面,第一:对属性评价的类(自己到Weka软件里看一下,英文Attribute Evaluator),第二:搜索的方式(自己到Weka软件里看一下,英文Search Method),第三:就是你要进行特征选择的数据集了。最后调用Filter的静态方法userFilter,感觉写的都是废话,一看代码就明白了。唯一值得一说的也就是别把AttributeSelection的包加错了,代码旁边有注释。

    package org.ml;
    
    import java.io.BufferedReader;
    import java.io.FileNotFoundException;
    import java.io.FileReader;
    import java.io.IOException;
    import java.util.Random;
    
    import weka.attributeSelection.CfsSubsetEval;
    import weka.attributeSelection.GreedyStepwise;
    import weka.classifiers.Classifier;
    import weka.classifiers.Evaluation;
    import weka.classifiers.meta.AttributeSelectedClassifier;
    import weka.classifiers.trees.J48;
    import weka.core.Instances;
    import weka.filters.Filter;
    import weka.filters.supervised.attribute.AttributeSelection;
    
    public class Test {
    
        public static Instances getFileInstances(String fileName)
                throws FileNotFoundException, IOException {
            Instances m_Instances = new Instances(new BufferedReader(
                    new FileReader(fileName)));
            m_Instances.setClassIndex(m_Instances.numAttributes() - 1);
            return m_Instances;
        }
    
        public static Evaluation crossValidation(Instances m_Instances,
                Classifier classifier, int numFolds) throws Exception {
            Evaluation evaluation = new Evaluation(m_Instances);
            evaluation.crossValidateModel(classifier, m_Instances, numFolds,
                    new Random(1));
            return evaluation;
        }
        
        public static Evaluation evaluateTestData(Instances m_Instances, Classifier classifier) throws Exception {
            int split = (int) (m_Instances.numInstances() * 0.6);
            Instances traindata = new Instances(m_Instances, 0, split);
            Instances testdata = new Instances(m_Instances, split, m_Instances.numInstances() - split);
            classifier.buildClassifier(traindata);
            //下面一行是m_Instances,或traindata,或testdata都没关系,因为Evaluation构造方法要的只是instance的结构,比如属性
            Evaluation evaluation = new Evaluation(m_Instances);
            evaluation.evaluateModel(classifier, testdata);
            return evaluation;
        }
        
        public static Instances selectAttrUseFilter(Instances m_Instances) throws Exception {
            AttributeSelection filter = new AttributeSelection();
            filter.setEvaluator(new CfsSubsetEval());
            filter.setSearch(new GreedyStepwise());
            filter.setInputFormat(m_Instances);
            return Filter.useFilter(m_Instances, filter);
        }
        
        public static void selectAttrUseMC(Instances m_Instances, Classifier base) throws Exception {
            AttributeSelectedClassifier classifier = new AttributeSelectedClassifier();
            classifier.setClassifier(base);
            classifier.setEvaluator(new CfsSubsetEval());
            classifier.setSearch(new GreedyStepwise());
            Evaluation evaluation = new Evaluation(m_Instances);
            evaluation.crossValidateModel(classifier, m_Instances, 10, new Random(1));
            System.out.println(evaluation.toSummaryString());
        }
        
        public static void printEvalDetail(Evaluation evaluation) throws Exception {
            System.out.println(evaluation.toClassDetailsString());
            System.out.println(evaluation.toSummaryString());
            System.out.println(evaluation.toMatrixString());
        }
    
        public static void main(String[] args) throws Exception {
            
            Instances data = getFileInstances("C:\Program Files\Weka-3-7\data\soybean.arff");
            //交叉验证
            Evaluation crossEvaluation = crossValidation(data, new J48(), 10);
            printEvalDetail(crossEvaluation);
            
            System.out.println("=====================================");
            //一般分类器分类,部分数据用于train,部分用于test
            Evaluation testEvaluation = evaluateTestData(data, new J48());
            printEvalDetail(testEvaluation);
            
            System.out.println("=====================================");
            //特征筛选
            Instances newData = selectAttrUseFilter(data);
            System.out.println("Oral data:" + data.numAttributes());
            System.out.println("selected data:" + newData.numAttributes());
            testEvaluation = evaluateTestData(newData, new J48());
            printEvalDetail(testEvaluation);
            
            System.out.println("=====================================");
            selectAttrUseMC(data, new J48());
            
    
    //        System.out.println("=====================================");
    //        J48 classifer = new J48();
    //        classifer.buildClassifier(data);
    //        for (int i = 0; i < data.numInstances(); i++) {
    //        //输出每个样例被分到的类别,如果是二分,分别表示为0和1
    // System.out.println(data.instance(i) + " === " + classifer.classifyInstance(data.instance(i))); // } } }
  • 相关阅读:
    C# 语言基础(++和--运算)
    Android RSA加密解密
    laravel redis
    larave5.1l队列
    shell更改目录编码
    Linux的权限说明
    MySQL主从架构之Master-Master互为主备
    php curl流方式远程下载大文件
    php session 跨页失效问题
    mysql中文字段转拼音首字母,以及中文拼音模糊查询
  • 原文地址:https://www.cnblogs.com/549294286/p/3299377.html
Copyright © 2011-2022 走看看