zoukankan      html  css  js  c++  java
  • Spark MLlib机器学习(一)——决策树

    决策树模型,适用于分类、回归。 
    简单地理解决策树呢,就是通过不断地设置新的条件标准对当前的数据进行划分,最后以实现把原始的杂乱的所有数据分类。

    就像下面这个图,如果输入是一大堆追求一个妹子的汉子,妹子内心里有个筛子,最后菇凉也就决定了和谁约(举栗而已哦,不代表什么~大家理解原理重要~~)

    训练数据:

    0,32 帅 收入中等 不是公务员
    1,25 帅 收入中等 是公务员
    0,25 帅 收入中等 不是公务员
    1,29 帅 收入中等 是公务员
    1,24 帅 收入高 不是公务员
    0,31 帅 收入高 不是公务员
    0,35 帅 收入中等 是公务员
    0,30 不帅 收入中等 不是公务员
    0,31 帅 收入高 不是公务员
    1,30 帅 收入中等 是公务员
    1,21 帅 收入高 不是公务员
    0,21 帅 收入中等 不是公务员
    1,21 帅 收入中等 是公务员
    0,29 不帅 收入中等 是公务员
    0,29 帅 收入底 是公务员
    0,29 不帅 收入底 是公务员
    1,30 帅 收入高 不是公务员

    测试数据:

    0,32 帅 收入中等 不是公务员
    1,27 帅 收入高 是公务员
    1,29 帅 收入高 不是公务员
    1,25 帅 收入中等 是公务员
    0,23 不帅 收入中等 是公务员

    代码实现:

    package com.test;
    
    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.Map;
    
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.api.java.function.PairFunction;
    import org.apache.spark.api.java.function.VoidFunction;
    import org.apache.spark.mllib.feature.HashingTF;
    import org.apache.spark.mllib.linalg.Vector;
    import org.apache.spark.mllib.linalg.Vectors;
    import org.apache.spark.mllib.regression.LabeledPoint;
    import org.apache.spark.mllib.tree.DecisionTree;
    import org.apache.spark.mllib.tree.model.DecisionTreeModel;
    import org.apache.spark.sql.SparkSession;
    
    import scala.Tuple2;
    
    public class DecisionTreeTest2 {
    
    	public static void main(String[] args) {
    
    		//SparkConf conf = new SparkConf().setMaster("local").setAppName("DecisionTreeTest").config("spark.sql.warehouse.dir","file:///D://test").getOrCreate() ;		
    		SparkSession spark = SparkSession.builder().master("local[5]")
    				.appName("DecisionTreeTest")
    				.config("spark.sql.warehouse.dir", "/user/hive/warehouse/").enableHiveSupport()
    				.getOrCreate();
    		
    		JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
    		JavaRDD<String> lines = jsc.textFile("C://tree3.txt");
    		
    		final HashingTF tf = new HashingTF(10000);
    		
    		JavaRDD<LabeledPoint> transdata = lines.map(new Function<String, LabeledPoint>() {
    			private static final long serialVersionUID = 1L;
    
    			@Override
    			public LabeledPoint call(String str) throws Exception {
    				String[] t1 = str.split(",");
    				String[] t2 = t1[1].split(" ");
    				LabeledPoint lab = new LabeledPoint(Double.parseDouble(t1[0]),tf.transform(Arrays.asList(t2)));
    				return lab;
    			}
    		});
    		// 设置决策树参数,训练模型
    		Integer numClasses = 3;
    		Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
    		String impurity = "gini";
    		Integer maxDepth = 5;
    		Integer maxBins = 32;
    		final DecisionTreeModel tree_model = DecisionTree.trainClassifier(transdata, numClasses,
    				categoricalFeaturesInfo, impurity, maxDepth, maxBins);
    		System.out.println("决策树模型:");
    		System.out.println(tree_model.toDebugString());
    		// 保存模型
    		tree_model.save(jsc.sc(), "C://DecisionTreeModel");
    
    		// 未处理数据,带入模型处理
    		JavaRDD<String> testLines = jsc.textFile("C://tree4.txt");
    		JavaPairRDD<String, String> res = testLines.mapToPair(new PairFunction<String, String, String>() {
    			private static final long serialVersionUID = 1L;
    
    			@Override
    			public Tuple2<String, String> call(String line) throws Exception {
    				String[] t2 = line.split(",")[1].split(" ");
    				Vector v = tf.transform(Arrays.asList(t2));
    				double res = tree_model.predict(v);
    				return new Tuple2<String, String>(line, Double.toString(res));
    			}
    		}).cache();
    		// 打印结果
    		res.foreach(new VoidFunction<Tuple2<String, String>>() {
    			private static final long serialVersionUID = 1L;
    
    			@Override
    			public void call(Tuple2<String, String> a) throws Exception {
    				System.out.println(a._1 + " : " + a._2);
    			}
    		});
    		// 将结果保存在本地
    		res.saveAsTextFile("C://res");
    
    	}
    
    }

    测试结果:

    0,32 帅 收入中等 不是公务员 : 0.0
    1,27 帅 收入高 是公务员 : 1.0
    1,29 帅 收入高 不是公务员 : 1.0
    1,25 帅 收入中等 是公务员 : 1.0
    0,23 不帅 收入中等 是公务员 : 0.0

  • 相关阅读:
    Cglib的动态代理
    idea中隐藏.idea文件夹和.iml文件
    JDBC工具类创建及使用
    JDBC的配置及使用入门
    mybatis的入门
    动态代理的具体实现
    【Flask】WTForms基本使用
    【Flask】Flask-Migrate基本使用
    【Flask】Flask-Sqlalchemy使用笔记
    【Flask】Sqlalchemy 子查询
  • 原文地址:https://www.cnblogs.com/gmhappy/p/9472428.html
Copyright © 2011-2022 走看看