zoukankan      html  css  js  c++  java
  • 每日一题 为了工作 2020 0507 第六十五题

    package data.bysj.tree;
    
    import org.apache.spark.Accumulator;
    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.api.java.function.VoidFunction;
    import org.apache.spark.mllib.linalg.Vectors;
    import org.apache.spark.mllib.regression.LabeledPoint;
    import org.apache.spark.mllib.tree.RandomForest;
    import org.apache.spark.mllib.tree.model.RandomForestModel;
    import org.apache.spark.sql.SparkSession;
    import scala.Tuple2;
    
    import java.util.HashMap;
    import java.util.Map;
    
    /**
     * Trains a Spark MLlib random-forest binary classifier on a tab-separated
     * data file, evaluates it on a held-out split and reports the accuracy.
     *
     * @author 雪瞳
     * @Slogan The clock keeps moving forward, so how could a person stop here!
     * @Function Random-forest classification of the records in ./save/rootData.
     *
     */
    public class RandomForestTrees {

        /**
         * Entry point: loads {@code ./save/rootData}, splits it 70/30 into training and
         * test sets, trains a 3-tree forest, prints the total and correct prediction
         * counts plus the accuracy, and prints the model's debug string when the
         * accuracy reaches the 80% standard.
         *
         * @param args unused
         */
        public static void main(String[] args) {
            String name = "forest";
            String master = "local[3]";
            SparkSession session = SparkSession.builder().master(master).appName(name).getOrCreate();
            try {
                JavaSparkContext jsc = JavaSparkContext.fromSparkContext(session.sparkContext());
                jsc.setLogLevel("ERROR");

                JavaRDD<LabeledPoint> metaData = jsc.textFile("./save/rootData")
                        .map(RandomForestTrees::parseRecord);

                // Fixed seed so the 70/30 split is reproducible across runs.
                JavaRDD<LabeledPoint>[] metaDataSource = metaData.randomSplit(new double[]{0.7, 0.3}, 10L);
                // Both sets are traversed several times (training, count, filter), so cache
                // them instead of re-reading and re-parsing the input file on every pass.
                JavaRDD<LabeledPoint> trainingData = metaDataSource[0].cache();
                JavaRDD<LabeledPoint> testData = metaDataSource[1].cache();

                int numClass = 2;
                // Empty map: every feature is treated as continuous.
                Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
                int numTrees = 3;
                String featureSubsetStrategy = "auto";
                String impurity = "entropy";
                int maxDepth = 4;
                int maxBins = 32;
                int seed = 1;
                RandomForestModel model = RandomForest.trainClassifier(
                        trainingData,
                        numClass,
                        categoricalFeaturesInfo,
                        numTrees,
                        featureSubsetStrategy,
                        impurity,
                        maxDepth,
                        maxBins,
                        seed);

                // Pair each prediction (_1) with its true label (_2) in a single pass rather
                // than zipping two separately mapped RDDs.
                JavaPairRDD<Double, Double> resultRDD = testData.mapToPair(
                        point -> new Tuple2<>(model.predict(point.features()), point.label()));

                long count = resultRDD.count();
                // filter().count() replaces the deprecated Accumulator-based counting.
                long value = resultRDD.filter(tp -> Double.compare(tp._2(), tp._1()) == 0).count();
                System.err.println("数目是:" + count);
                System.err.println("正确数目是:" + value);
                if (count == 0) {
                    // Avoid 0/0 = NaN when the random split leaves the test set empty.
                    System.err.println("测试集为空,无法计算正确率");
                    return;
                }
                double rate = value / (double) count;
                System.err.println("正确率是:" + rate * 100 + "%");

                // Accuracy standard in percent. rate is a fraction in [0, 1], so scale it
                // before comparing; only a model that meets the standard is kept.
                double stand = 80.00;
                if (Double.compare(rate * 100, stand) >= 0) {
                    // To persist the model: model.save(jsc.sc(), "./save/model");
                    System.out.println(model.toDebugString());
                }
            } finally {
                session.stop();
            }
        }

        /**
         * Parses one tab-separated record into a {@link LabeledPoint}.
         *
         * <p>Expected layout, e.g.
         * {@code "2015-11-01 20:20:16"\t1.85\t1.22\t2.51\t-0.40\t0.01\t\t0}:
         * the first field (timestamp) and the second-to-last field (an empty column
         * in the sample data) are skipped, every field in between is a feature, and
         * the last field is the label.
         *
         * @param line one raw input line
         * @return the labeled point built from {@code line}
         * @throws IllegalArgumentException if the line has fewer than four fields,
         *         i.e. cannot yield at least one feature
         * @throws NumberFormatException if a feature or the label is not numeric
         */
        static LabeledPoint parseRecord(String line) {
            String[] fields = line.split("\t");
            if (fields.length < 4) {
                throw new IllegalArgumentException(
                        "Malformed record (expected at least 4 tab-separated fields): " + line);
            }
            double[] features = new double[fields.length - 3];
            for (int i = 0; i < features.length; i++) {
                features[i] = Double.parseDouble(fields[i + 1]);
            }
            double label = Double.parseDouble(fields[fields.length - 1]);
            return new LabeledPoint(label, Vectors.dense(features));
        }
    }
    

      

     
  • 相关阅读:
    《JavaScript 源码分析》之 jquery.unobtrusive-ajax.js
    《JavaScript高级程序设计》读书笔记 2
    《JS设计模式笔记》构造函数和工厂模式创建对象
    《ES6基础教程》之 map、forEach、filter indexOf 用法
    《JS设计模式笔记》 5,适配器模式
    51Nod 1058 N的阶乘的长度
    ACM总结——2017区域赛网络赛总结
    ACM-ICPC国际大学生程序设计竞赛北京赛区(2017)网络赛 题目9 : Minimum
    hiho一下 第168周
    Fast Matrix Calculation HDU
  • 原文地址:https://www.cnblogs.com/walxt/p/12843902.html
Copyright © 2011-2022 走看看