zoukankan      html  css  js  c++  java
  • 在使用 Pipeline 串联多个 stage 时，Model（已 fit）与非 Model（未 fit 的 Estimator）的区别

    train.csv数据:

    id,name,age,sex
    1,lyy,20,F
    2,rdd,20,M
    3,nyc,18,M
    4,mzy,10,M

    数据读取:

     1 SparkSession  spark = SparkSession.builder().enableHiveSupport()
     2                     .getOrCreate();
     3         Dataset<Row> dataset = spark
     4                 .read()
     5                 .format("org.apache.spark.sql.execution.datasources.csv.CSVFileFormat")
     6                 .option("header", true)
     7                 .option("inferSchema", true)
     8                 .option("delimiter", ",")
     9                 //.load("file:///E:/git/bigdata_sparkIDE/spark-ide/workspace/test/SparkMLTest/SanFranciscoCrime/document/kaggle-旧金山犯罪分类/train-new.csv") //PreProcess1
    10                 .load("file:///E:/git/bigdata_sparkIDE/spark-ide/workspace/test/SparkMLTest/DataPreprocessing/document/train.csv") //PreProcess2
    11                 .persist();
     1     public static void PreProcess2(Dataset<Row> data) {
     2         
     3                 data.printSchema();
     4                 // 重新索引标签值
     5                 StringIndexerModel labelIndexer = new StringIndexer()
     6                 .setInputCol("sex")
     7                 .setOutputCol("label")
     8                 .fit(data);
     9                 
    10                 StringIndexerModel nameIndexer = new StringIndexer()
    11                 .setInputCol("name")
    12                 .setOutputCol("namenum")
    13                 .fit(data);
    14 
    15                 
    16                 /*  会报错:Exception in thread "main" java.lang.IllegalArgumentException: Field "namenum" does not exist.
    17                  * 原因是:Model类型调用fit时,要求数据集中必须包含InputCol所指定的列名
    18                  * 不会将Pipeline某个stage的输出作为InputCol,即使那个stage的OutputCol指定的列名与其相同也不行
    19                  * StringIndexerModel name1Indexer = new StringIndexer()
    20                 .setInputCol("namenum")
    21                 .setOutputCol("namenum1")
    22                 .fit(data);*/
    23                 
    24                 
    25                 /* 错误原因StringIndexerModel错误一样,features并不是data的列
    26                  * VectorIndexerModel featureIndexer = new VectorIndexer()
    27                     .setInputCol("features")
    28                     .setOutputCol("indexfeatures")
    29                     .setMaxCategories(4)
    30                     .fit(data);*/
    31                 
    32                 //成功
    33                 //原因说明:非model时,转换器不会调用fit,而会使用Pipeline某个stage的输出作为InputCol
    34                 //由于stage[2]即 assembler已经生成features,故而该处直接使用;
    35                 //但是该类型时不能单独使用,必须依赖Pipeline
    36                 VectorIndexer featureIndexer = new VectorIndexer()
    37                 .setInputCol("features")
    38                 .setOutputCol("indexfeatures")
    39                 .setMaxCategories(4);
    40                 
    41                 //由上述分析可知,该处输入的列可以是多个stage的输出组成,因为VectorAssembler非model
    42                 //因此可以使用中间生成结果,且可以使用多个
    43                 VectorAssembler assembler = new VectorAssembler()
    44                 .setInputCols("id,namenum,age".split(","))
    45                .setOutputCol("features");
    46                 
    47                 //这里的stage的顺序很重要,一定按照依赖关系顺序放入,如下顺序就会报错:
    48                 //Exception in thread "main" java.lang.IllegalArgumentException: Field "features" does not exist.
    49                 //Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {labelIndexer,nameIndexer,featureIndexer,assembler});
    50                 
    51                 //将featureIndexer放到assembler即可
    52                 Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {labelIndexer,nameIndexer,assembler,featureIndexer});
    53 
    54                 // Train model. This also runs the indexers.
    55                 PipelineModel model = pipeline.fit(data);
    56 
    57                 // Make predictions.
    58                 Dataset<Row> result = model.transform(data);
    59                 
    60                 result.show(10, false);
    61                 
    62     }

    运行结果:

    root
     |-- id: integer (nullable = true)
     |-- name: string (nullable = true)
     |-- age: integer (nullable = true)
     |-- sex: string (nullable = true)

    +---+----+---+---+-----+-------+--------------+-------------+
    |id |name|age|sex|label|namenum|features      |indexfeatures|
    +---+----+---+---+-----+-------+--------------+-------------+
    |1  |lyy |20 |F  |1.0  |1.0    |[1.0,1.0,20.0]|[0.0,1.0,2.0]|
    |2  |rdd |20 |M  |0.0  |2.0    |[2.0,2.0,20.0]|[1.0,2.0,2.0]|
    |3  |nyc |18 |M  |0.0  |0.0    |[3.0,0.0,18.0]|[2.0,0.0,1.0]|
    |4  |mzy |10 |M  |0.0  |3.0    |[4.0,3.0,10.0]|[3.0,3.0,0.0]|
    +---+----+---+---+-----+-------+--------------+-------------+

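    综上分析，labelIndexer 和 nameIndexer 其实也不必提前 fit：把所有 StringIndexer 都以未 fit 的 Estimator 形式放入 Pipeline，pipeline.fit(data) 会按 stage 顺序依次 fit 并 transform，每个 stage 都能使用前面 stage 生成的列。因此可以将原有代码简化如下: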

     1 public static void PreProcess2(Dataset<Row> data) {
     2         
     3                 data.printSchema();
     4                 // 重新索引标签值
     5                 StringIndexer labelIndexer = new StringIndexer()
     6                 .setInputCol("sex")
     7                 .setOutputCol("label");
     8                 
     9                 StringIndexer nameIndexer = new StringIndexer()
    10                 .setInputCol("name")
    11                 .setOutputCol("namenum");
    12 
    13                 VectorIndexer featureIndexer = new VectorIndexer()
    14                 .setInputCol("features")
    15                 .setOutputCol("indexfeatures")
    16                 .setMaxCategories(4);
    17                 
    18             
    19                 VectorAssembler assembler = new VectorAssembler()
    20                 .setInputCols("id,namenum,age".split(","))
    21                .setOutputCol("features");
    22 
    23                 Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {labelIndexer,nameIndexer,assembler,featureIndexer});
    24 
    25                 // Train model. This also runs the indexers.
    26                 PipelineModel model = pipeline.fit(data);  //以这里的data为基准数据
    27 
    28                 // Make predictions.
    29                 Dataset<Row> result = model.transform(data);
    30                 
    31                 result.show(10, false);
    32                 
    33     }

    运行结果:

    root
     |-- id: integer (nullable = true)
     |-- name: string (nullable = true)
     |-- age: integer (nullable = true)
     |-- sex: string (nullable = true)
    
    +---+----+---+---+-----+-------+--------------+-------------+
    |id |name|age|sex|label|namenum|features      |indexfeatures|
    +---+----+---+---+-----+-------+--------------+-------------+
    |1  |lyy |20 |F  |1.0  |1.0    |[1.0,1.0,20.0]|[0.0,1.0,2.0]|
    |2  |rdd |20 |M  |0.0  |2.0    |[2.0,2.0,20.0]|[1.0,2.0,2.0]|
    |3  |nyc |18 |M  |0.0  |0.0    |[3.0,0.0,18.0]|[2.0,0.0,1.0]|
    |4  |mzy |10 |M  |0.0  |3.0    |[4.0,3.0,10.0]|[3.0,3.0,0.0]|
    +---+----+---+---+-----+-------+--------------+-------------+
  • 相关阅读:
    Scala之eq,equals,==的区别
    Spark Streaming流计算特点及代码案例
    刷题50—水壶问题
    刷题49(力扣3道题)
    刷题48——最长回文串
    刷题47——矩形重叠
    刷题46——拼写单词
    刷题45(力扣两道题)
    刷题44——岛屿的最大面积
    刷题43——最长上升子序列
  • 原文地址:https://www.cnblogs.com/lyy-blog/p/9523026.html
Copyright © 2011-2022 走看看