  • Simulating an "active learning" algorithm with Weka

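    Active learning:

    Active learning requires the classifier to interact with a labeling expert. A typical process:

    (1) Build a model from a small number of labeled samples

    (2) Select the most informative samples from the unlabeled pool and hand them to the expert for labeling

    (3) Merge these samples with the previously labeled ones and rebuild the model

    (4) Repeat steps (2) and (3) until the stopping criterion is met (no unlabeled samples remain, or some other condition)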
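
    The four steps above can be sketched without any library. This is only a toy illustration under stated assumptions: the samples are 1-D points in [0, 1], the "model" is a single threshold, and the expert is replaced by a hypothetical oracle function. The class name ActiveLoopSketch and its methods are made up for this sketch and are not part of Weka.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.DoublePredicate;

// Hypothetical toy setup: 1-D points in [0, 1], a threshold "model",
// and an oracle standing in for the human expert.
class ActiveLoopSketch {

    // Steps (1) and (3): fit a threshold halfway between the largest
    // labeled negative and the smallest labeled positive.
    static double fit(List<Double> xs, List<Boolean> ys) {
        double maxNeg = 0.0, minPos = 1.0; // assumes all points lie in [0, 1]
        for (int i = 0; i < xs.size(); i++) {
            if (ys.get(i)) minPos = Math.min(minPos, xs.get(i));
            else maxNeg = Math.max(maxNeg, xs.get(i));
        }
        return (maxNeg + minPos) / 2.0;
    }

    // Runs the loop and returns the final threshold.
    static double run(List<Double> labeledX, List<Boolean> labeledY,
                      List<Double> pool, DoublePredicate oracle, int budget) {
        double threshold = fit(labeledX, labeledY);        // step (1)
        // Step (4): stop when the pool is empty or the budget is spent.
        while (!pool.isEmpty() && budget-- > 0) {
            // Step (2): treat the point closest to the current decision
            // boundary as the most informative one.
            int best = 0;
            for (int i = 1; i < pool.size(); i++) {
                if (Math.abs(pool.get(i) - threshold)
                        < Math.abs(pool.get(best) - threshold)) best = i;
            }
            double x = pool.remove(best);
            labeledX.add(x);
            labeledY.add(oracle.test(x));                  // the "expert" labels it
            threshold = fit(labeledX, labeledY);           // step (3)
        }
        return threshold;
    }

    public static void main(String[] args) {
        List<Double> lx = new ArrayList<>(Arrays.asList(0.0, 1.0));
        List<Boolean> ly = new ArrayList<>(Arrays.asList(false, true));
        List<Double> pool = new ArrayList<>(Arrays.asList(0.1, 0.3, 0.45, 0.6, 0.72, 0.9));
        // The true boundary 0.65 is known only to the oracle.
        System.out.println(run(lx, ly, pool, x -> x >= 0.65, 3));
    }
}
```

    With a budget of three queries, the loop only asks about points near the current boundary and leaves the far-away ones unlabeled, which is the point of step (2).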

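    Simulation approach:

    1. Split the data into a labeled set and an unlabeled set

    2. Divide the unlabeled set into groups of 100; within each group, compute the entropy of every sample, sort by entropy, and add the top 5 samples to the labeled set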
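
    The selection rule in step 2 can be shown in isolation. The following is a minimal sketch that assumes a binary problem where each unlabeled sample already has a predicted positive-class probability; the class EntropyBatchSketch and its method names are hypothetical and chosen only for this illustration.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

class EntropyBatchSketch {

    // Binary entropy (in bits) of a predicted positive-class probability p.
    static double binaryEntropy(double p) {
        if (p <= 0.0 || p >= 1.0) return 0.0; // a certain prediction carries no uncertainty
        return -(p * Math.log(p) + (1 - p) * Math.log(1 - p)) / Math.log(2.0);
    }

    // For each consecutive group of groupSize probabilities, return the
    // indices of the (up to) k entries with the highest entropy.
    static List<Integer> selectPerGroup(double[] probs, int groupSize, int k) {
        List<Integer> chosen = new ArrayList<>();
        for (int start = 0; start < probs.length; start += groupSize) {
            int end = Math.min(start + groupSize, probs.length); // last group may be short
            List<Integer> idx = new ArrayList<>();
            for (int i = start; i < end; i++) idx.add(i);
            // Sort indices by descending entropy (most uncertain first).
            idx.sort(Comparator.comparingDouble((Integer i) -> -binaryEntropy(probs[i])));
            chosen.addAll(idx.subList(0, Math.min(k, idx.size())));
        }
        return chosen;
    }

    public static void main(String[] args) {
        double[] probs = {0.9, 0.5, 0.1, 0.6, 0.99, 0.45};
        // Groups of 3, keep the single most uncertain sample per group.
        System.out.println(selectPerGroup(probs, 3, 1));
    }
}
```

    Probabilities near 0.5 rank first and those near 0 or 1 rank last, so the samples handed over for labeling are the ones the model is least sure about.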

    package demo;
    
    
    import java.io.FileReader;
    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.Random;
    import weka.classifiers.Evaluation;
    import weka.classifiers.bayes.NaiveBayes;
    import weka.core.Instance;
    import weka.core.Instances;
    import weka.core.converters.ConverterUtils.DataSource;
    
    // Pairs an instance with its entropy so that candidates can be sorted, highest entropy first
    class InstanceSort implements Comparable<InstanceSort>{
        public Instance instance;
        public double entropy;
        
        public InstanceSort( Instance instance, double entropy){
            this.instance = instance;
            this.entropy = entropy;
        }
        @Override
        public int compareTo(InstanceSort o) {
            // Descending order: the larger entropy sorts first
            return Double.compare(o.entropy, this.entropy);
        }
    }
    
    public class ActiveLearning {
    
        public static Instances getInstances( String fileName) throws Exception{
            Instances data = new Instances (new FileReader(fileName));
            data.setClassIndex(data.numAttributes()-1);
            return data;
        }
        
        // Binary entropy (in bits) of the predicted probability of one class
        public static double computeEntropy(double predictValue){
            // Treat probabilities within 1e-9 of 0 or 1 as certain to avoid log(0)
            if ( 1-predictValue < 0.000000001d || predictValue < 0.000000001d){
                return 0;
            }else {
                return -predictValue*(Math.log(predictValue)/Math.log(2.0d))-(1-predictValue)*(Math.log(1-predictValue)/Math.log(2.0d));
            }
        }
        
        public static void classify(Instances train, Instances test) throws Exception{
            NaiveBayes classifier = new NaiveBayes();
            // Train the model
            classifier.buildClassifier(train);
            
            // Evaluate the model
            Evaluation eval = new Evaluation(test);
            eval.evaluateModel(classifier, test);
            System.out.println(eval.toClassDetailsString());
        }
        
        // Uncertainty sampling
        public static Instances uncertaintySample(Instances labeled, Instances unlabeled, int start, int end) throws Exception{
            // Train the model on the labeled data first
            NaiveBayes classifier = new NaiveBayes();
            classifier.buildClassifier(labeled);
            // Collect each candidate together with its entropy
            ArrayList <InstanceSort> l = new ArrayList<InstanceSort>();
            
            for (int i = start; i < end; i++) {
                // Use the predicted class probability, not the predicted label:
                // classifyInstance returns a class index (0.0 or 1.0), whose entropy is always 0.
                // computeEntropy is a binary entropy, so a two-class problem is assumed.
                double[] dist = classifier.distributionForInstance(unlabeled.instance(i));
                double entropy =  computeEntropy (dist[0]);
                InstanceSort is = new InstanceSort(unlabeled.instance(i), entropy);
                l.add(is);
            }
            // Sort by entropy, highest first
            Collections.sort(l);
            
            // Empty dataset that reuses the header of the unlabeled data
            Instances chosenInstances = new Instances(unlabeled, 0);
            // From each group, pick the (up to) 5 instances with the highest entropy
            for(int i = 0; i < Math.min(5, l.size()); i++){
                chosenInstances.add(l.get(i).instance);
            }
            
            return chosenInstances;
        }
        
        // Sampling
        public static void sample( Instances instances, Instances test) throws Exception{
            Random rand = new Random(1023);
            instances.randomize(rand);
            instances.stratify(10);
            // Fold 0 (about 10%) is the initial labeled set; the rest is treated as unlabeled
            Instances unlabeled = instances.trainCV(10, 0);
            Instances labeled = instances.testCV(10, 0);
            
            int total = unlabeled.numInstances();
            
            // Process the unlabeled data in groups of 100 (the last group may be smaller)
            for ( int start = 0; start < total; start += 100){
                int end = Math.min(start + 100, total);
                // From each group, pick the 5 instances with the highest entropy
                Instances resultInstances = uncertaintySample(labeled, unlabeled, start, end);
                for (int j = 0; j < resultInstances.numInstances(); j++){
                    labeled.add(resultInstances.instance(j));
                }
                // Retrain on the enlarged labeled set and evaluate
                classify(labeled, test);
            }
        
        }
        
        public static void main(String[] args)  throws Exception{
            Instances instances = getInstances("NASA//pc1.arff");
            
            // Hold out one fold of a stratified 10-fold split as the test set
            Random rand = new Random(1023);
            instances.randomize(rand);
            instances.stratify(10);
            Instances train = instances.trainCV(10, 0);
            Instances test = instances.testCV(10, 0);
    //        System.out.println(train.numInstances());
    //        System.out.println(test.numInstances());
            
            sample(train,test);
    
        }
    
    }
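
    The computeEntropy method above only covers two classes. For datasets with more classes, one possible generalization is the Shannon entropy of the whole class-probability vector, such as the array returned by Weka's distributionForInstance. The sketch below is an assumption about how that could look, not part of the original program; DistributionEntropySketch is a hypothetical name.

```java
class DistributionEntropySketch {

    // Shannon entropy (in bits) of a class-probability vector.
    // Zero-probability classes contribute nothing to the sum.
    static double entropy(double[] dist) {
        double h = 0.0;
        for (double p : dist) {
            if (p > 0.0) h -= p * Math.log(p) / Math.log(2.0);
        }
        return h;
    }

    public static void main(String[] args) {
        // A uniform distribution over four classes is maximally uncertain.
        System.out.println(entropy(new double[] {0.25, 0.25, 0.25, 0.25}));
        // A one-hot distribution is fully certain.
        System.out.println(entropy(new double[] {1.0, 0.0, 0.0, 0.0}));
    }
}
```

    For two classes this reduces to the same binary entropy, so the ranking produced by the program would not change on a binary dataset.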
  • Original article: https://www.cnblogs.com/douzujun/p/8410939.html