zoukankan      html  css  js  c++  java
  • 调用weka模拟实现 “主动学习” 算法

    主动学习:

    主动学习的过程:需要分类器与标记专家进行交互。一个典型的过程:

    (1)基于少量已标记样本构建模型

    (2)从未标记样本中选出信息量最大的样本,交给专家进行标记

    (3)将这些样本与之前样本进行融合,并构建模型

    (4)重复执行步骤(2)和步骤(3),直到stopping criterion(不存在未标记样本或其他条件)满足为止

    模拟思路:

    1. 将数据分为label 和 unlabel数据集

    2. 将 unlabel 分为100个一组,对每组中的样本分别求出熵值,按熵值从大到小排序,取前5个样本,添加到 label 样本之中

    package demo;
    
    
    import java.io.FileReader;
    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.Random;
    import weka.classifiers.Evaluation;
    import weka.classifiers.bayes.NaiveBayes;
    import weka.core.Instance;
    import weka.core.Instances;
    import weka.core.converters.ConverterUtils.DataSource;
    
    //将测试用例,按照熵值进行排序
    /**
     * Pairs an instance with an entropy score so that candidates can be ranked.
     *
     * <p>The natural ordering is <em>descending</em> by entropy: after
     * {@code Collections.sort}, the entry with the highest entropy comes first.
     */
    class InstanceSort implements Comparable<InstanceSort>{
        /** The instance being ranked. */
        public Instance instance;
        /** Entropy score attached to {@link #instance}. */
        public double entropy;

        /**
         * @param instance the instance to rank
         * @param entropy  the entropy score used as the sort key
         */
        public InstanceSort(Instance instance, double entropy) {
            this.instance = instance;
            this.entropy = entropy;
        }

        /**
         * Orders entries from highest to lowest entropy.
         *
         * <p>{@link Double#compare(double, double)} is used instead of raw
         * {@code <}/{@code >} tests so that the ordering stays total even when a
         * score is NaN (NaN entries sort first); chained relational tests would
         * report NaN as "equal" to everything and break the sort contract.
         */
        @Override
        public int compareTo(InstanceSort o) {
            // Arguments are swapped on purpose to obtain descending order.
            return Double.compare(o.entropy, this.entropy);
        }
    }
    
    public class ActiveLearning {
    
        /**
         * Loads a dataset from the given file and marks its last attribute as the class.
         *
         * @param fileName path of the file to read
         * @return the loaded dataset with the class index set to the last attribute
         * @throws Exception if the file cannot be opened or parsed
         */
        public static Instances getInstances(String fileName) throws Exception {
            FileReader reader = new FileReader(fileName);
            try {
                Instances data = new Instances(reader);
                data.setClassIndex(data.numAttributes() - 1);
                return data;
            } finally {
                // Release the file handle whether or not parsing succeeded.
                reader.close();
            }
        }
        
        //计算熵
        /**
         * Computes the two-outcome entropy, in bits, of a probability value.
         *
         * <p>Evaluates {@code -p*log2(p) - (1-p)*log2(1-p)}. When {@code p} is below
         * 1e-9 or above {@code 1 - 1e-9} the result is defined as 0, which also
         * avoids taking the logarithm of zero.
         *
         * @param predictValue the probability {@code p} of one of the two outcomes
         * @return the entropy in bits, or 0 for the (near-)degenerate cases above
         */
        public static double computeEntropy(double predictValue) {
            final double epsilon = 0.000000001d;
            final double complement = 1 - predictValue;

            // Guard clause: (near-)certain outcomes carry no uncertainty.
            if (complement < epsilon || predictValue < epsilon) {
                return 0;
            }

            final double ln2 = Math.log(2.0d);
            final double log2OfValue = Math.log(predictValue) / ln2;
            final double log2OfComplement = Math.log(complement) / ln2;
            return -predictValue * log2OfValue - complement * log2OfComplement;
        }
        
        /**
         * Trains a NaiveBayes model on {@code train}, evaluates it on {@code test}
         * and prints the per-class detail report to standard output.
         *
         * @param train instances used to build the classifier
         * @param test  instances the trained classifier is evaluated on
         * @throws Exception if training or evaluation fails
         */
        public static void classify(Instances train, Instances test) throws Exception {
            NaiveBayes classifier = new NaiveBayes();
            // Train the model.
            classifier.buildClassifier(train);

            // Evaluate the model. The Evaluation object is initialised from the
            // training data so that its header/prior information does not come
            // from the very set the model is being scored on.
            Evaluation eval = new Evaluation(train);
            eval.evaluateModel(classifier, test);
            System.out.println(eval.toClassDetailsString());
        }
        
        //不确定采样
        /**
         * Picks the most uncertain instances from a slice of the unlabeled pool.
         *
         * <p>A NaiveBayes model is trained on {@code labeled}; every instance of
         * {@code unlabeled} with index in {@code [start, end)} is scored by the
         * entropy of the model's predicted class distribution, and the (up to) five
         * highest-entropy instances are returned.
         *
         * @param labeled   instances used to train the scoring model
         * @param unlabeled pool from which candidates are drawn
         * @param start     index of the first candidate (inclusive)
         * @param end       index after the last candidate (exclusive)
         * @return a new dataset sharing the header of {@code unlabeled} and holding
         *         at most five instances; empty when the slice is empty
         * @throws Exception if the classifier cannot be built or applied
         */
        public static Instances uncertaintySample(Instances labeled, Instances unlabeled, int start, int end) throws Exception {
            // Train the scoring model on the currently labeled data.
            NaiveBayes classifier = new NaiveBayes();
            classifier.buildClassifier(labeled);

            ArrayList<InstanceSort> ranked = new ArrayList<InstanceSort>();
            for (int i = start; i < end; i++) {
                Instance candidate = unlabeled.instance(i);
                // Use the class probability distribution: classifyInstance() only
                // yields the predicted class index, whose entropy is always zero.
                double[] distribution = classifier.distributionForInstance(candidate);
                ranked.add(new InstanceSort(candidate, distributionEntropy(distribution)));
            }
            // InstanceSort orders by descending entropy, so the most uncertain come first.
            Collections.sort(ranked);

            // Reuse the pool's header instead of re-reading a dataset file from disk.
            Instances chosenInstances = new Instances(unlabeled, 0);
            // Take the five highest-entropy instances, or fewer if the slice is smaller.
            int picks = Math.min(5, ranked.size());
            for (int i = 0; i < picks; i++) {
                chosenInstances.add(ranked.get(i).instance);
            }

            return chosenInstances;
        }

        /** Shannon entropy, in bits, of a probability distribution; zero-probability terms contribute nothing. */
        private static double distributionEntropy(double[] distribution) {
            double bits = 0.0;
            for (int k = 0; k < distribution.length; k++) {
                double p = distribution[k];
                if (p > 0.0) {
                    bits -= p * (Math.log(p) / Math.log(2.0d));
                }
            }
            return bits;
        }
        
        //采样
        /**
         * Simulates the active-learning loop on {@code instances}.
         *
         * <p>The data is shuffled with a fixed seed, stratified into 10 folds, and
         * split using fold 0: the {@code testCV} part becomes the initial labeled
         * set and the {@code trainCV} part the unlabeled pool. The pool is then
         * walked in consecutive batches of 100 (the last batch holds the remainder
         * and may be empty when the pool size is an exact multiple of 100). For each
         * batch the instances returned by {@code uncertaintySample} are appended to
         * the labeled set and a model trained on it is evaluated against {@code test}.
         *
         * @param instances data to split into labeled seed and unlabeled pool;
         *                  randomized and stratified in place
         * @param test      held-out data every intermediate model is evaluated on
         * @throws Exception if sampling, training or evaluation fails
         */
        public static void sample(Instances instances, Instances test) throws Exception {
            instances.randomize(new Random(1023));
            instances.stratify(10);
            final Instances pool = instances.trainCV(10, 0);
            final Instances seed = instances.testCV(10, 0);

            final int poolSize = pool.numInstances();
            final int batchCount = poolSize / 100 + 1;

            for (int batch = 0; batch < batchCount; batch++) {
                // Batches are 100 wide; only the final one is cut off at the pool size.
                final int from = batch * 100;
                final int to = Math.min(from + 100, poolSize);

                final Instances picked = uncertaintySample(seed, pool, from, to);
                for (int k = 0; k < picked.numInstances(); k++) {
                    seed.add(picked.instance(k));
                }
                classify(seed, test);
            }
        }
        
        /**
         * Entry point: loads the dataset, carves out fold 0 of a stratified 10-fold
         * split as the test set, and runs the active-learning simulation on the rest.
         *
         * @param args command-line arguments (not used)
         * @throws Exception if loading, sampling or evaluation fails
         */
        public static void main(String[] args) throws Exception {
            Instances dataset = getInstances("NASA//pc1.arff");

            // Fixed seed keeps the shuffle, and therefore the split, reproducible.
            dataset.randomize(new Random(1023));
            dataset.stratify(10);
            Instances trainFold = dataset.trainCV(10, 0);
            Instances testFold = dataset.testCV(10, 0);

            sample(trainFold, testFold);
        }
    
    }
  • 相关阅读:
    计算机综合面试题总结
    MySQL入门很简单: 13 数据备份和还原
    MySQL入门很简单: 12 MYSQL 用户管理
    MySQL入门很简单: 11 mysql函数
    页面即使加过了移除监听事件,但是到新页面后事件任然存在
    git命令大全
    document.documentElement.scrollTop指定位置失效解决办法
    vue做商品选择如何保持样式
    vue父组件向子组件传对象,不实时更新解决
    vue 遮罩层阻止默认滚动事件
  • 原文地址:https://www.cnblogs.com/douzujun/p/8410939.html
Copyright © 2011-2022 走看看