zoukankan      html  css  js  c++  java
  • sparkML原始数据转换成label-features方法

    数据1:kaggle-旧金山犯罪分类数据
    格式如下:
    Dates,Category,Descript,DayOfWeek,PdDistrict,Resolution,Address,X,Y
    2015-05-13 23:53:00,WARRANTS,WARRANT ARREST,Wednesday,NORTHERN,"ARREST, BOOKED",OAK ST / LAGUNA ST,-122.425891675136,37.7745985956747
    2015-05-13 23:53:00,OTHER OFFENSES,TRAFFIC VIOLATION ARREST,Wednesday,NORTHERN,"ARREST, BOOKED",OAK ST / LAGUNA ST,-122.425891675136,37.7745985956747
    2015-05-13 23:33:00,OTHER OFFENSES,TRAFFIC VIOLATION ARREST,Wednesday,NORTHERN,"ARREST, BOOKED",VANNESS AV / GREENWICH ST,-122.42436302145,37.8004143219856
    2015-05-13 23:30:00,LARCENY/THEFT,GRAND THEFT FROM LOCKED AUTO,Wednesday,NORTHERN,NONE,1500 Block of LOMBARD ST,-122.42699532676599,37.80087263276921
    2015-05-13 23:30:00,LARCENY/THEFT,GRAND THEFT FROM LOCKED AUTO,Wednesday,PARK,NONE,100 Block of BRODERICK ST,-122.438737622757,37.771541172057795
    2015-05-13 23:30:00,LARCENY/THEFT,GRAND THEFT FROM UNLOCKED AUTO,Wednesday,INGLESIDE,NONE,0 Block of TEDDY AV,-122.40325236121201,37.713430704116
    2015-05-13 23:30:00,VEHICLE THEFT,STOLEN AUTOMOBILE,Wednesday,INGLESIDE,NONE,AVALON AV / PERU AV,-122.423326976668,37.7251380403778
    2015-05-13 23:30:00,VEHICLE THEFT,STOLEN AUTOMOBILE,Wednesday,BAYVIEW,NONE,KIRKWOOD AV / DONAHUE ST,-122.371274317441,37.7275640719518
    2015-05-13 23:00:00,LARCENY/THEFT,GRAND THEFT FROM LOCKED AUTO,Wednesday,RICHMOND,NONE,600 Block of 47TH AV,-122.508194031117,37.776601260681204
    
    测试代码:
    
        public static void main(String[] args) {
    
            SparkSession spark = SparkSession.builder().enableHiveSupport()
                    .getOrCreate();
            Dataset<Row> dataset = spark
                    .read()
                    .format("org.apache.spark.sql.execution.datasources.csv.CSVFileFormat")
                    .option("header", true)
                    .option("inferSchema", true)
                    .option("delimiter", ",")
                    .load("file:///E:/git/bigdata_sparkIDE/spark-ide/workspace/test/SparkMLTest/SanFranciscoCrime/document/kaggle-旧金山犯罪分类/train-new.csv")
                    .persist();
    
            DataPreProcess(dataset);
    
        }
    
       //此函数包含StringIndexer,OneHotEncoder,VectorAssembler,VectorIndexer数据转换方法
        public static Dataset<Row> DataPreProcess(Dataset<Row> data) {
    
            //Dataset<Row> df = data.selectExpr("cast(Dates as String) ,DayOfWeek,PdDistrict,Category".split(","));
    
            Dataset<Row> df = data.select(data.col("Dates").cast("String").alias("Dates"),data.col("DayOfWeek").alias("DayOfWeek"),data.col("PdDistrict"),data.col("Category"));
            df.printSchema();
            // 重新索引标签值
    
            SparkLog.info(data.select("Category").distinct().count());
    
            //将非数字类型标签转换成数字类型,按照标签去重的个数n,编号0~n,相同标签的多行记录转换后的数字标签编号相同
            //这个适合所有非数字且不连续的有限类别数据编号,不仅仅是只能编号标签
            StringIndexerModel labelIndexer = new StringIndexer()
                    .setInputCol("Category").setOutputCol("label").fit(df);
    
            StringIndexerModel DateIndexer = new StringIndexer()
                    .setInputCol("Dates").setOutputCol("DatesNum").fit(df);
    
            StringIndexerModel DayOfWeekIndexer = new StringIndexer()
                    .setInputCol("DayOfWeek").setOutputCol("dfNum").fit(df);
    
            StringIndexerModel PdDistrictIndexer = new StringIndexer()
                    .setInputCol("PdDistrict").setOutputCol("pdNum").fit(df);
    
            /*独热编码将类别特征(离散的,已经转换为数字编号形式(这个是必须的,否则会报错),
            映射成独热编码,生成的是一个稀疏向量
            比如字符串"abcab"的映射规则:去重后的特征个数n即为稀疏向量的维数,而数字编号代
            表该特征对应的向量中非0值的下标,最后生成0-1编码的向量
            a  1 0 0
            b  0 1 0
            c  0 0 1
            a  1 0 0
            b  0 1 0
            */
            
            //OneHotEncoder不需要fit
            OneHotEncoder encoder = new OneHotEncoder().setInputCol("dfNum")
                    .setOutputCol("dfvec")
                    .setDropLast(false);  // 设置最后一个是否包含
    
            OneHotEncoder encoder1 = new OneHotEncoder().setInputCol("pdNum")
                    .setOutputCol("pdvec")
                    .setDropLast(false);// 设置最后一个是否包含
    
            OneHotEncoder encoder2 = new OneHotEncoder().setInputCol("DatesNum")
                    .setOutputCol("Datesvec")
                    .setDropLast(false);// 设置最后一个是否包含
    
            //将多个列拼接成一个向量,列的类型可以是向量
            VectorAssembler assembler = new VectorAssembler().setInputCols(
                    "Datesvec,dfvec,pdvec".split(",")).setOutputCol("features");
    
            // Dataset<Row> assembledFeatures = assembler.transform(df);
    
            Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {
                    DateIndexer, DayOfWeekIndexer, PdDistrictIndexer, encoder,
                    encoder1, encoder2, labelIndexer, assembler });
    
            // Train model. This also runs the indexers.
            PipelineModel model = pipeline.fit(df);
    
            // Make predictions.
            Dataset<Row> predictions = model.transform(df);
            predictions.describe("label").show();
            predictions.show(100, false);
            
            return predictions;
    
        }
    
    +-------------------+---------+----------+--------------+--------+-----+-----+-------------+--------------+-----------------------+-----+---------------------------------------------+
    |Dates |DayOfWeek|PdDistrict|Category |DatesNum|dfNum|pdNum|dfvec |pdvec |Datesvec |label|features |
    +-------------------+---------+----------+--------------+--------+-----+-----+-------------+--------------+-----------------------+-----+---------------------------------------------+
    |2015-05-13 23:53:00|Wednesday|NORTHERN |WARRANTS |172231.0|1.0 |2.0 |(7,[1],[1.0])|(10,[2],[1.0])|(389257,[172231],[1.0])|7.0 |(389274,[172231,389258,389266],[1.0,1.0,1.0])|
    |2015-05-13 23:53:00|Wednesday|NORTHERN |OTHER OFFENSES|172231.0|1.0 |2.0 |(7,[1],[1.0])|(10,[2],[1.0])|(389257,[172231],[1.0])|1.0 |(389274,[172231,389258,389266],[1.0,1.0,1.0])|
    |2015-05-13 18:05:00|Wednesday|BAYVIEW |LARCENY/THEFT |330092.0|1.0 |3.0 |(7,[1],[1.0])|(10,[3],[1.0])|(389257,[330092],[1.0])|0.0 |(389274,[330092,389258,389267],[1.0,1.0,1.0])|
    |2015-05-13 18:02:00|Wednesday|MISSION |OTHER OFFENSES|387792.0|1.0 |1.0 |(7,[1],[1.0])|(10,[1],[1.0])|(389257,[387792],[1.0])|1.0 |(389274,[387792,389258,389265],[1.0,1.0,1.0])|
    |2015-05-13 18:00:00|Wednesday|SOUTHERN |BURGLARY |32607.0 |1.0 |0.0 |(7,[1],[1.0])|(10,[0],[1.0])|(389257,[32607],[1.0]) |8.0 |(389274,[32607,389258,389264],[1.0,1.0,1.0]) |
    |2015-05-13 18:00:00|Wednesday|BAYVIEW |LARCENY/THEFT |32607.0 |1.0 |3.0 |(7,[1],[1.0])|(10,[3],[1.0])|(389257,[32607],[1.0]) |0.0 |(389274,[32607,389258,389267],[1.0,1.0,1.0]) |
    |2015-05-13 18:00:00|Wednesday|PARK |LARCENY/THEFT |32607.0 |1.0 |8.0 |(7,[1],[1.0])|(10,[8],[1.0])|(389257,[32607],[1.0]) |0.0 |(389274,[32607,389258,389272],[1.0,1.0,1.0]) |
    +-------------------+---------+----------+--------------+--------+-----+-----+-------------+--------------+-----------------------+-----+---------------------------------------------+
    only showing top 7 rows
    *******************************************************************************************************************
    
    数据2:
    
    id,name,age,sex,rate
    1,lyy,20,F,0.6
    2,rdd,20,M,0.4
    3,nyc,18,M,0.55
    4,mzy,10,M,0.21
    1 //Binarizer二值化: 将该列数据二值化,大于阈值的为1.0,否则为0.0  spark源码:udf { in: Double => if (in > td) 1.0 else 0.0 }
    2 
    3 Dataset<Row> result = new Binarizer()
    4                 .setInputCol("rate")
    5                 .setOutputCol("flag")
    6                 .setThreshold(0.5).transform(data);
    7                 
    8                 result.show(10, false);
    +---+----+---+---+----+----+
    |id |name|age|sex|rate|flag|
    +---+----+---+---+----+----+
    |1 |lyy |20 |F |0.6 |1.0 |
    |2 |rdd |20 |M |0.4 |0.0 |
    |3 |nyc |18 |M |0.55|1.0 |
    |4 |mzy |10 |M |0.21|0.0 |
    +---+----+---+---+----+----+
     1 //IndexToString将stringindexder转换的数据转回到原始的数据
     2 
     3  StringIndexer labelIndexer = new StringIndexer()
     4                  .setInputCol("sex")
     5                  .setOutputCol("label");
     6                  
     7                  IndexToString IndexToSex = new  IndexToString()
     8                              .setInputCol("label")
     9                              .setOutputCol("orisex");
    10                  
    11                  Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {    labelIndexer,    IndexToSex});
    12                  PipelineModel model = pipeline.fit(data);
    13  
    14                  // Make predictions.
    15                  Dataset<Row> result = model.transform(data);
    16                  
    17                  result.show(10, false);

     

     1                 //Bucketizer 分箱(分段处理):将连续数值转换为离散类别
     2                 //比如特征是年龄,是一个连续数值,需要将其转换为离散类别(未成年人、青年人、中年人、老年人),就要用到Bucketizer了
     3                 //如age > 55 老年人
     4                 double[] splits={0,18,35,55,Double.POSITIVE_INFINITY};//[0,18),[18,35),[35,55),[55,正无穷)
     5                 Dataset<Row> result=new Bucketizer()
     6                  .setInputCol("age")
     7                  .setOutputCol("bucketCategory")
     8                  .setSplits(splits)//设置分段标准
     9                  .transform(data);
    10 
    11                 result.show(10, false);

    
    
    
  • 相关阅读:
    final .....finally ...... 和Finalize ......区别
    MyEclipse中常用的快捷键大全,快来.....
    简单的描述Java中的构造函数,及访问修饰符
    分层开发---酒店管理系统---
    C#深入.NET平台的软件系统分层开发
    影院售票系统-----一个让你有成就感的小项目,只有一丢丢哦
    mysql数据库进阶
    MySQL练习
    MySQL中常见函数
    TCP协议之三次握手四次挥手
  • 原文地址:https://www.cnblogs.com/lyy-blog/p/9518177.html
Copyright © 2011-2022 走看看