  • Notes on Learning Spark MLlib (Java Version)

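    This post describes how to use Spark's MLlib to build a machine learning model and use it to predict on new data. The overall workflow is shown in the figure below: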

     

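    Loading the data

    For loading and saving data, MLlib provides the MLUtils class, which its documentation describes as "Helper methods to load, save and pre-process data used in MLLib." The data used in this post is sample_libsvm_data.txt, which ships with Spark; it contains 100 samples with 658 features. The data looks as shown in the figure: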

    Loading the LIBSVM file

    JavaRDD<LabeledPoint> lpdata = MLUtils.loadLibSVMFile(sc, this.libsvmFile).toJavaRDD();
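
    To make the LIBSVM text format concrete, here is a minimal plain-Java sketch of how one line (`<label> <index>:<value> ...`) maps to a label plus a sparse feature vector. The class `LibSvmLine` is a hypothetical helper written for this illustration, not a Spark API; it assumes the file's indices are one-based and shifts them to zero-based, which is what `MLUtils.loadLibSVMFile` is documented to do.

```java
// Hypothetical helper for illustration only; not part of Spark.
class LibSvmLine {
    final double label;
    final int[] indices;   // zero-based feature indices
    final double[] values; // feature values at those indices

    LibSvmLine(double label, int[] indices, double[] values) {
        this.label = label;
        this.indices = indices;
        this.values = values;
    }

    // Parses one line of the form "<label> <index>:<value> <index>:<value> ...".
    // Indices in the file are assumed one-based and are shifted to zero-based.
    static LibSvmLine parse(String line) {
        String[] tokens = line.trim().split("\\s+");
        double label = Double.parseDouble(tokens[0]);
        int n = tokens.length - 1;
        int[] indices = new int[n];
        double[] values = new double[n];
        for (int i = 0; i < n; i++) {
            String[] pair = tokens[i + 1].split(":");
            indices[i] = Integer.parseInt(pair[0]) - 1;
            values[i] = Double.parseDouble(pair[1]);
        }
        return new LibSvmLine(label, indices, values);
    }

    public static void main(String[] args) {
        LibSvmLine p = LibSvmLine.parse("1 3:0.5 10:2.0");
        System.out.println(p.label + " "
                + java.util.Arrays.toString(p.indices) + " "
                + java.util.Arrays.toString(p.values));
    }
}
```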

    The LabeledPoint type corresponds to one record of a LIBSVM-format file. It consists of a label (a double) and a feature vector (a Vector).

    Converting to a DataFrame

    JavaRDD<Row> jrow = lpdata.map(new LabeledPointToRow());
    StructType schema = new StructType(new StructField[]{
                        new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                        new StructField("features", new VectorUDT(), false, Metadata.empty()),
            });
    SQLContext jsql = new SQLContext(sc);
    DataFrame df = jsql.createDataFrame(jrow, schema);

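    DataFrame: A DataFrame is a distributed dataset organized into named columns. Conceptually it is equivalent to a table in a relational database or a data frame in Python (or R), but with more optimizations under the hood. A DataFrame can be constructed from structured data files, Hive tables, external databases, or existing RDDs.

    SQLContext: The entry point to all Spark SQL functionality is the SQLContext class or one of its subclasses. Creating a basic SQLContext requires only a SparkContext.

    Feature extraction

    Feature standardization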

    StandardScaler scaler = new StandardScaler().setInputCol("features").setOutputCol("normFeatures").setWithStd(true);
    DataFrame scalerDF = scaler.fit(df).transform(df);
    scaler.save(this.scalerModelPath);
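
    As a rough illustration of what scaling to unit standard deviation does, the plain-Java sketch below divides each feature by its column's standard deviation. `UnitStdScaler` is a hypothetical helper, not Spark's implementation; it assumes no mean centering (the `StandardScaler` default), the sample standard deviation with an n - 1 denominator, and an output of 0.0 for zero-variance columns.

```java
// Hypothetical helper for illustration only; not Spark's implementation.
class UnitStdScaler {
    // Returns the sample standard deviation (n - 1 denominator) of each column.
    static double[] fit(double[][] rows) {
        int n = rows.length;
        int d = rows[0].length;
        double[] std = new double[d];
        for (int j = 0; j < d; j++) {
            double mean = 0.0;
            for (double[] row : rows) mean += row[j];
            mean /= n;
            double ss = 0.0;
            for (double[] row : rows) ss += (row[j] - mean) * (row[j] - mean);
            std[j] = n > 1 ? Math.sqrt(ss / (n - 1)) : 0.0;
        }
        return std;
    }

    // Divides each feature by its column's standard deviation.
    // Columns with zero standard deviation are mapped to 0.0 (an assumption of this sketch).
    static double[] transform(double[] row, double[] std) {
        double[] out = new double[row.length];
        for (int j = 0; j < row.length; j++) {
            out[j] = std[j] == 0.0 ? 0.0 : row[j] / std[j];
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] rows = {{2.0, 5.0}, {4.0, 5.0}, {6.0, 5.0}};
        double[] std = fit(rows);
        System.out.println(java.util.Arrays.toString(transform(rows[2], std)));
    }
}
```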

    Feature selection with the chi-squared statistic

    ChiSqSelector selector = new ChiSqSelector().setNumTopFeatures(500).setFeaturesCol("normFeatures").setLabelCol("label").setOutputCol("selectedFeatures");
    ChiSqSelectorModel chiModel = selector.fit(scalerDF);
    DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");
    chiModel.save(this.featureSelectedModelPath);
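
    `ChiSqSelector` ranks features using a chi-squared test of independence between each feature and the label, treating feature values as categories. The sketch below computes the Pearson chi-squared statistic for a single feature from a contingency table of counts. `ChiSquared` is a hypothetical helper for illustration and makes no claim to match Spark's internals.

```java
// Hypothetical helper for illustration only; not Spark's implementation.
class ChiSquared {
    // observed[i][j] = number of samples with feature category i and label j.
    // Returns the Pearson chi-squared statistic: sum of (observed - expected)^2 / expected.
    static double statistic(long[][] observed) {
        int r = observed.length;
        int c = observed[0].length;
        double[] rowSum = new double[r];
        double[] colSum = new double[c];
        double total = 0.0;
        for (int i = 0; i < r; i++) {
            for (int j = 0; j < c; j++) {
                rowSum[i] += observed[i][j];
                colSum[j] += observed[i][j];
                total += observed[i][j];
            }
        }
        double chi2 = 0.0;
        for (int i = 0; i < r; i++) {
            for (int j = 0; j < c; j++) {
                // Expected count if feature and label were independent.
                double expected = rowSum[i] * colSum[j] / total;
                if (expected > 0.0) {
                    double diff = observed[i][j] - expected;
                    chi2 += diff * diff / expected;
                }
            }
        }
        return chi2;
    }

    public static void main(String[] args) {
        long[][] dependent = {{10, 20}, {20, 10}};
        long[][] independent = {{10, 10}, {10, 10}};
        System.out.println(statistic(dependent));
        System.out.println(statistic(independent));
    }
}
```

    A feature whose counts look the same for every label scores near zero, while a feature whose distribution shifts with the label scores higher and is more likely to be kept.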

    Training the machine learning model (SVM as an example)

    // Convert back to LabeledPoint for training
    JavaRDD<Row> selectedrows = selectedDF.javaRDD();
    JavaRDD<LabeledPoint> trainset = selectedrows.map(new RowToLabel());
    
    // Train the SVM model and save it
    int numIteration = 200;
    SVMModel model = SVMWithSGD.train(trainset.rdd(), numIteration);
    model.clearThreshold();
    model.save(sc, this.mlModelPath);
    
    // Converts a LabeledPoint into a Row
    static class LabeledPointToRow implements Function<LabeledPoint, Row> {
    
            public Row call(LabeledPoint p) throws Exception {
                double label = p.label();
                Vector vector = p.features();
                return RowFactory.create(label, vector);
            }
        }
    
    // Converts a Row back into a LabeledPoint
    static class RowToLabel implements Function<Row, LabeledPoint> {
    
            public LabeledPoint call(Row r) throws Exception {
                Vector features = r.getAs(1);
                double label = r.getDouble(0);
                return new LabeledPoint(label, features);
            }
        }
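
    Because `model.clearThreshold()` is called before saving, `predict` returns a raw score rather than a 0/1 label. The plain-Java sketch below shows how a linear model produces that score and how a threshold would turn it into a label. `LinearSvmScore` and its parameters are hypothetical names for illustration, and the strict `>` comparison against the threshold is an assumption of this sketch.

```java
// Hypothetical helper for illustration only; not Spark's implementation.
class LinearSvmScore {
    // Raw score: dot product of weights and features, plus the intercept.
    static double margin(double[] weights, double intercept, double[] features) {
        double sum = intercept;
        for (int i = 0; i < weights.length; i++) {
            sum += weights[i] * features[i];
        }
        return sum;
    }

    // Maps a raw score to a 0/1 label; scores above the threshold become 1.0 (assumption).
    static double toLabel(double margin, double threshold) {
        return margin > threshold ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        double[] weights = {1.0, -2.0};
        double score = margin(weights, 0.5, new double[]{3.0, 1.0});
        System.out.println("score=" + score + ", label=" + toLabel(score, 0.0));
    }
}
```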

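    Testing new samples

    Before new samples can be scored, they must go through the same data conversion and feature extraction steps as the training data. So during training, in addition to the machine learning model itself, the intermediate feature extraction models must also be saved. The code is as follows: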

    // Initialize Spark
    SparkConf conf = new SparkConf().setAppName("SVM").setMaster("local");
    conf.set("spark.testing.memory", "2147480000");
    SparkContext sc = new SparkContext(conf);
    
    // Load the test data
    JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, this.predictDataPath).toJavaRDD();
    
    // Convert to a DataFrame
    JavaRDD<Row> jrow = testData.map(new LabeledPointToRow());
    StructType schema = new StructType(new StructField[]{
                        new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                        new StructField("features", new VectorUDT(), false, Metadata.empty()),
            });
    SQLContext jsql = new SQLContext(sc);
    DataFrame df = jsql.createDataFrame(jrow, schema);
    
    // Standardize the features. Note: this refits the loaded scaler on the test data,
    // so the scaling statistics come from the test set rather than the training set.
    StandardScaler scaler = StandardScaler.load(this.scalerModelPath);
    DataFrame scalerDF = scaler.fit(df).transform(df);
    
    // Feature selection
    ChiSqSelectorModel chiModel = ChiSqSelectorModel.load(this.featureSelectedModelPath);
    DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");
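
    Note that the snippet above calls `fit` again on the test DataFrame, so the test features are scaled with statistics from the test data instead of the training data. The plain-Java sketch below, with made-up numbers, shows how much the two can differ. `FittedScaleReuse` is a hypothetical helper for illustration; in Spark, one way to reuse the training statistics would be to persist the fitted model returned by `fit` rather than the unfitted scaler.

```java
// Hypothetical helper for illustration only; the numbers are made up.
class FittedScaleReuse {
    // Sample standard deviation (n - 1 denominator) of one feature column.
    static double sampleStd(double[] column) {
        int n = column.length;
        if (n < 2) return 0.0;
        double mean = 0.0;
        for (double v : column) mean += v;
        mean /= n;
        double ss = 0.0;
        for (double v : column) ss += (v - mean) * (v - mean);
        return Math.sqrt(ss / (n - 1));
    }

    // Scales one value by a given standard deviation; zero maps to 0.0.
    static double scale(double value, double std) {
        return std == 0.0 ? 0.0 : value / std;
    }

    public static void main(String[] args) {
        double[] train = {2.0, 4.0, 6.0};
        double[] test = {10.0, 20.0, 30.0};
        double trainStd = sampleStd(train); // statistics the model was trained with
        double testStd = sampleStd(test);   // statistics obtained by refitting on test data
        System.out.println("reusing training std: " + scale(test[0], trainStd));
        System.out.println("refitting on test:    " + scale(test[0], testStd));
    }
}
```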

    Scoring the test set

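    // Convert the selected features back to LabeledPoint, as was done for training
    JavaRDD<LabeledPoint> testset = selectedDF.javaRDD().map(new RowToLabel());
    SVMModel svmmodel = SVMModel.load(sc, this.mlModelPath);
    JavaRDD<Tuple2<Double, Double>> predictResult = testset.map(new Prediction(svmmodel));
    predictResult.collect();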
    
    static class Prediction implements Function<LabeledPoint, Tuple2<Double , Double>> {
            SVMModel model;
            public Prediction(SVMModel model){
                this.model = model;
            }
            public Tuple2<Double, Double> call(LabeledPoint p) throws Exception {
                Double score = model.predict(p.features());
                return new Tuple2<Double , Double>(score, p.label());
            }
        }

    Computing accuracy

    double accuracy = predictResult.filter(new PredictAndScore()).count() * 1.0 / predictResult.count();
    System.out.println(accuracy);
    
    static class PredictAndScore implements Function<Tuple2<Double, Double>, Boolean> {
            public Boolean call(Tuple2<Double, Double> t) throws Exception {
                double score = t._1();
                double label = t._2();
                System.out.println("score:" + score + ", label:" + label);
                // MLlib's SVM uses 0/1 labels, so map the raw score to 0/1 before comparing
                double predicted = score >= 0.0 ? 1.0 : 0.0;
                return predicted == label;
            }
        }
  • Original article: https://www.cnblogs.com/itboys/p/9692594.html