  • Simulating the "active learning" algorithm with Weka

    Active learning:

    Active learning requires the classifier to interact with a labeling expert (oracle). A typical process:

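    (1) Build a model from a small set of labeled samples.

    (2) Select the most informative samples from the unlabeled data and hand them to the expert for labeling.

    (3) Merge the newly labeled samples with the previous ones and rebuild the model.

    (4) Repeat steps (2) and (3) until a stopping criterion is met (no unlabeled samples remain, or some other condition holds).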
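    As a rough illustration of steps (1) to (4), here is a toy, Weka-free sketch of pool-based uncertainty sampling. Everything in it is a hypothetical assumption rather than the post's setup: 1-D points on a grid in [0, 1], a simulated expert that labels x >= 0.6 as positive, and a "model" that is just a threshold between the known classes with a logistic probability curve around it. The class and method names are made up for this example.

```java
import java.util.ArrayList;
import java.util.List;

// Toy pool-based active learning loop (uncertainty sampling) on 1-D data.
// Hypothetical assumptions: the expert reveals the true label (x >= 0.6), and the
// model is a threshold midway between the largest known negative and the
// smallest known positive.
final class ActiveLoopSketch {

    // Steps (1)/(3): "train" the threshold model on the labeled points.
    static double fitThreshold(List<Double> xs, List<Boolean> ys) {
        double maxNeg = 0.0, minPos = 1.0;   // assumes inputs lie in [0, 1]
        for (int i = 0; i < xs.size(); i++) {
            if (ys.get(i)) minPos = Math.min(minPos, xs.get(i));
            else maxNeg = Math.max(maxNeg, xs.get(i));
        }
        return (maxNeg + minPos) / 2.0;
    }

    // P(positive | x) as a logistic curve around the threshold (steepness 50 is arbitrary).
    static double probPositive(double x, double threshold) {
        return 1.0 / (1.0 + Math.exp(-50.0 * (x - threshold)));
    }

    // Binary entropy in bits; 0 when the model is (almost) certain.
    static double entropy(double p) {
        if (p <= 1e-9 || p >= 1 - 1e-9) return 0.0;
        return -p * Math.log(p) / Math.log(2) - (1 - p) * Math.log(1 - p) / Math.log(2);
    }

    // Runs the loop for at most `budget` queries and returns the final threshold.
    static double run(int budget) {
        List<Double> pool = new ArrayList<>();
        for (int i = 0; i <= 100; i++) pool.add(i / 100.0);   // unlabeled pool
        List<Double> xs = new ArrayList<>();
        List<Boolean> ys = new ArrayList<>();
        // Step (1): a small labeled seed set, one point of each class.
        xs.add(0.0); ys.add(false);
        xs.add(1.0); ys.add(true);
        pool.remove(Double.valueOf(0.0));
        pool.remove(Double.valueOf(1.0));

        double threshold = fitThreshold(xs, ys);
        // Step (4): repeat until the query budget or the pool is exhausted.
        for (int q = 0; q < budget && !pool.isEmpty(); q++) {
            // Step (2): pick the pool point with the highest predictive entropy.
            int best = 0;
            for (int i = 1; i < pool.size(); i++) {
                if (entropy(probPositive(pool.get(i), threshold))
                        > entropy(probPositive(pool.get(best), threshold))) {
                    best = i;
                }
            }
            double x = pool.remove(best);
            xs.add(x);
            ys.add(x >= 0.6);                 // the simulated expert labels it
            // Step (3): retrain on the enlarged labeled set.
            threshold = fitThreshold(xs, ys);
        }
        return threshold;
    }

    public static void main(String[] args) {
        System.out.println(run(10));
    }
}
```

    Each query lands next to the current threshold, so in this toy setting the loop behaves like a binary search and `run(10)` ends with a threshold close to 0.6.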

    Simulation approach:

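    1. Split the data into a labeled set and an unlabeled set. Because the true labels are already in the dataset, revealing them plays the role of the expert.

    2. Split the unlabeled set into groups of 100 samples. For each group, compute the entropy of each sample's predicted class distribution, sort by entropy, and add the 5 highest-entropy samples to the labeled set; then retrain the model and evaluate it on a held-out test set.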
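    The group-wise selection in step 2 can be isolated from Weka. The sketch below is a hypothetical helper (its name and signature are not from the post): given one entropy value per unlabeled instance, it walks the pool in groups of `groupSize` and returns the indices of the `k` highest-entropy instances of each group. It assumes the last group may be shorter than `groupSize`, or even shorter than `k`, and simply takes what is there.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

// Hypothetical helper: per-group top-k selection by entropy.
final class BatchSelection {

    static List<Integer> selectPerGroup(double[] entropy, int groupSize, int k) {
        List<Integer> chosen = new ArrayList<>();
        for (int start = 0; start < entropy.length; start += groupSize) {
            // The last group is cut off at the end of the pool.
            int end = Math.min(start + groupSize, entropy.length);
            Integer[] idx = new Integer[end - start];
            for (int i = start; i < end; i++) idx[i - start] = i;
            // Sort this group's indices by entropy, highest first.
            Arrays.sort(idx, Comparator.comparingDouble((Integer i) -> entropy[i]).reversed());
            // Take the top k, or the whole group if it is smaller than k.
            for (int j = 0; j < Math.min(k, idx.length); j++) chosen.add(idx[j]);
        }
        return chosen;
    }

    public static void main(String[] args) {
        double[] entropy = {0.1, 0.9, 0.5, 0.3, 0.8, 0.2};
        // Groups of 4, top 2 per group: [0..3] -> 1, 2 and [4..5] -> 4, 5
        System.out.println(selectPerGroup(entropy, 4, 2));   // prints [1, 2, 4, 5]
    }
}
```

    Clamping the tail with `Math.min` is the design choice that matters here: without it, a final group with fewer than `k` instances would cause an out-of-bounds access.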

    package demo;
    
    
    import java.io.FileReader;
    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.Random;
    import weka.classifiers.Evaluation;
    import weka.classifiers.bayes.NaiveBayes;
    import weka.core.Instance;
    import weka.core.Instances;
    import weka.core.converters.ConverterUtils.DataSource;
    
    // Pairs an instance with its entropy so that instances can be sorted by entropy
    class InstanceSort implements Comparable<InstanceSort>{
        public Instance instance;
        public double entropy;
        
        public InstanceSort(Instance instance, double entropy){
            this.instance = instance;
            this.entropy = entropy;
        }

        // Descending order: after Collections.sort, the highest-entropy
        // (most uncertain) instances come first
        @Override
        public int compareTo(InstanceSort o) {
            return Double.compare(o.entropy, this.entropy);
        }
    }
    
    public class ActiveLearning {
    
        // Load an ARFF file and use the last attribute as the class
        public static Instances getInstances(String fileName) throws Exception{
            try (FileReader reader = new FileReader(fileName)) {
                Instances data = new Instances(reader);
                data.setClassIndex(data.numAttributes() - 1);
                return data;
            }
        }
        
        // Binary entropy (in bits) of a prediction whose probability for one class is p.
        // Assumes a two-class problem; returns 0 when the prediction is (almost) certain.
        public static double computeEntropy(double p){
            if (p < 1e-9 || 1 - p < 1e-9){
                return 0;
            }
            return -p * (Math.log(p) / Math.log(2.0d)) - (1 - p) * (Math.log(1 - p) / Math.log(2.0d));
        }
        
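        // Train Naive Bayes on train and print per-class metrics on test
        public static void classify(Instances train, Instances test) throws Exception{
            NaiveBayes classifier = new NaiveBayes();
            // Train the model
            classifier.buildClassifier(train);
            
            // Evaluate the model on the held-out test set
            // (header and class priors are taken from the training data)
            Evaluation eval = new Evaluation(train);
            eval.evaluateModel(classifier, test);
            System.out.println(eval.toClassDetailsString());
        }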
        
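        // Uncertainty sampling: from unlabeled[start, end), return the (at most) 5
        // instances whose predicted class distribution has the highest entropy
        public static Instances uncertaintySample(Instances labeled, Instances unlabeled, int start, int end) throws Exception{
            // Train a model on the currently labeled data
            NaiveBayes classifier = new NaiveBayes();
            classifier.buildClassifier(labeled);

            // Score each instance by the entropy of its predicted distribution.
            // classifyInstance() returns only the predicted class index, whose "entropy"
            // is always 0, so the class probabilities from distributionForInstance() are used.
            ArrayList<InstanceSort> l = new ArrayList<InstanceSort>();
            for (int i = start; i < end; i++) {
                double[] dist = classifier.distributionForInstance(unlabeled.instance(i));
                double entropy = computeEntropy(dist[0]);
                l.add(new InstanceSort(unlabeled.instance(i), entropy));
            }
            // Sort by entropy, highest first
            Collections.sort(l);

            // Empty set with the same header (attributes and class index) as the data
            Instances chosenInstances = new Instances(unlabeled, 0);
            // Take the 5 highest-entropy instances of the group (fewer if the group is smaller)
            for (int i = 0; i < Math.min(5, l.size()); i++){
                chosenInstances.add(l.get(i).instance);
            }

            return chosenInstances;
        }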
        
        // Simulate active learning: start from a small labeled set, then repeatedly
        // query the most uncertain instances from the unlabeled pool
        public static void sample(Instances instances, Instances test) throws Exception{
            Random rand = new Random(1023);
            instances.randomize(rand);
            instances.stratify(10);
            // 90% of the training data becomes the unlabeled pool, 10% the initial labeled set
            Instances unlabeled = instances.trainCV(10, 0);
            Instances labeled = instances.testCV(10, 0);

            // Walk the pool in groups of 100 (the last group may be smaller). From each
            // group, add the 5 highest-entropy instances, with their true labels standing
            // in for the expert, to the labeled set; then retrain and evaluate.
            int n = unlabeled.numInstances();
            for (int start = 0; start < n; start += 100){
                int end = Math.min(start + 100, n);
                Instances resultInstances = uncertaintySample(labeled, unlabeled, start, end);
                for (int j = 0; j < resultInstances.numInstances(); j++){
                    labeled.add(resultInstances.instance(j));
                }
                classify(labeled, test);
            }
        }
        
        public static void main(String[] args) throws Exception{
            Instances instances = getInstances("NASA//pc1.arff");
            
            // Stratified 10-fold split, of which only fold 0 is used: 90% for training
            // (the active learning pool) and 10% as the held-out test set
            Random rand = new Random(1023);
            instances.randomize(rand);
            instances.stratify(10);
            Instances train = instances.trainCV(10, 0);
            Instances test = instances.testCV(10, 0);
    //        System.out.println(train.numInstances());
    //        System.out.println(test.numInstances());
            
            sample(train,test);
    
        }
    
    }
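    The post's computeEntropy is the binary entropy, which assumes two classes. For data with more classes, a minimal sketch of the general Shannon entropy over a class-probability vector (such as the array Weka's distributionForInstance returns) could look like this. The class name is hypothetical. It also shows why the entropy has to be computed from predicted probabilities rather than from the predicted class alone: a one-hot vector always scores 0, so it cannot rank instances by uncertainty.

```java
// Hypothetical multi-class generalization of computeEntropy.
final class DistributionEntropy {

    // Shannon entropy in bits; classes with probability 0 contribute nothing.
    static double entropy(double[] dist) {
        double h = 0.0;
        for (double p : dist) {
            if (p > 0.0) h -= p * Math.log(p) / Math.log(2.0);
        }
        return h;
    }

    public static void main(String[] args) {
        // Knowing only the predicted class is equivalent to a one-hot vector: no uncertainty
        System.out.println(entropy(new double[] {1.0, 0.0}));
        // Maximal uncertainty for two classes (1 bit)
        System.out.println(entropy(new double[] {0.5, 0.5}));
        // Maximal uncertainty for four classes (2 bits)
        System.out.println(entropy(new double[] {0.25, 0.25, 0.25, 0.25}));
    }
}
```

    For two classes, entropy(new double[] {p, 1 - p}) gives the same value as the binary formula in computeEntropy, so it could replace it without changing the ranking.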
  • Original article: https://www.cnblogs.com/douzujun/p/8410939.html