  • Spark: computing the Top N within each group

    Prepare the test data source:

    c1 85
    c2 77
    c3 88
    c1 22
    c1 66
    c3 95
    c3 54
    c2 91
    c2 66
    c1 54
    c1 65
    c2 41
    c4 65

    Spark Scala implementation:

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession
    
    object GroupTopN1 {
      System.setProperty("hadoop.home.dir", "D:\\Java_Study\\hadoop-common-2.2.0-bin-master")
    
      case class Rating(userId: String, rating: Long)
    
      def main(args: Array[String]) {
        val sparkConf = new SparkConf().setAppName("GroupTopN")
        val spark = SparkSession
          .builder()
          .config(sparkConf)
          .master("local")
          .config("spark.sql.warehouse.dir", "/")
          .getOrCreate()
    
        import spark.implicits._
        import spark.sql
    
        val lines = spark.read.textFile("C:\\Users\\Administrator\\Desktop\\group.txt")
        val classScores = lines.map(line => Rating(line.split(" ")(0), line.split(" ")(1).toLong))
    
        classScores.createOrReplaceTempView("tb_test")
    
        // rn is a window column, so it cannot be filtered with HAVING;
        // rank in a subquery and filter with WHERE in the outer query
        val df = sql(
          s"""|select userId, rating, rn
              |from (
              |  select userId, rating,
              |    row_number() over (partition by userId order by rating desc) as rn
              |  from tb_test
              |) t
              |where rn <= 3
              |""".stripMargin)
        df.show()
    
        spark.stop()
      }
    }

    Output:

    +------+------+---+
    |userId|rating| rn|
    +------+------+---+
    |    c1|    85|  1|
    |    c1|    66|  2|
    |    c1|    65|  3|
    |    c4|    65|  1|
    |    c3|    95|  1|
    |    c3|    88|  2|
    |    c3|    54|  3|
    |    c2|    91|  1|
    |    c2|    77|  2|
    |    c2|    66|  3|
    +------+------+---+
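
    As a sketch, the same ranking can also be expressed without SQL through the DataFrame API's window functions (this reuses the classScores Dataset built above; Window and row_number are standard Spark SQL APIs, byUser is just a local name for the sketch):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{col, row_number}

    // number the rows inside each userId partition, highest rating first,
    // then keep the first three rows of every group
    val byUser = Window.partitionBy("userId").orderBy(col("rating").desc)
    classScores.withColumn("rn", row_number().over(byUser))
      .where(col("rn") <= 3)
      .show()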

    Spark Java implementation:

    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.sql.*;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;
    
    import java.util.ArrayList;
    import java.util.List;
    
    public class Test {
        public static void main(String[] args) {
            SparkConf sparkConf = new SparkConf().setAppName("GroupTopN");
            SparkSession spark = SparkSession
                    .builder()
                    .config(sparkConf)
                    .master("local")
                    .config("spark.sql.warehouse.dir", "/")
                    .getOrCreate();
    
    
            // Create an RDD
            JavaRDD<String> peopleRDD = spark.sparkContext()
                    .textFile("C:\\Users\\Administrator\\Desktop\\group.txt", 1)
                    .toJavaRDD();
    
            // Build the schema: userId (string), rating (long)
            List<StructField> fields = new ArrayList<>();
            StructField field1 = DataTypes.createStructField("userId", DataTypes.StringType, true);
            StructField field2 = DataTypes.createStructField("rating", DataTypes.LongType, true);
            fields.add(field1);
            fields.add(field2);
            StructType schema = DataTypes.createStructType(fields);
    
            // Convert records of the RDD (people) to Rows
            JavaRDD<Row> rowRDD = peopleRDD.map((Function<String, Row>) record -> {
                String[] attributes = record.split(" ");
                if (attributes.length != 2) {
                    throw new Exception("Malformed record: " + record);
                }
                return RowFactory.create(attributes[0],Long.valueOf( attributes[1].trim()));
            });
    
            // Apply the schema to the RDD
            Dataset<Row> peopleDataFrame = spark.createDataFrame(rowRDD, schema);
    
            peopleDataFrame.createOrReplaceTempView("tb_test");
    
            Dataset<Row> items = spark.sql("select userId, rating, rn from (" +
                    "select userId, rating, row_number() over (partition by userId order by rating desc) rn " +
                    "from tb_test) t " +
                    "where rn <= 3");
            items.show();
    
            spark.stop();
        }
    }

    The output is the same as above.

    Implementing Top N with combineByKey in Java:

    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.FlatMapFunction;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.api.java.function.Function2;
    import org.apache.spark.api.java.function.PairFunction;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;
    import scala.Tuple2;
    
    import java.util.*;
    
    public class SparkJava {
        public static void main(String[] args) {
            SparkSession spark = SparkSession.builder().master("local[*]").appName("Spark").getOrCreate();
            final JavaSparkContext ctx = JavaSparkContext.fromSparkContext(spark.sparkContext());
    
            List<String> data = Arrays.asList("a,110,a1", "b,122,b1", "c,123,c1", "a,210,a2", "b,212,b2", "a,310,a3", "b,312,b3", "a,410,a4", "b,412,b4");
            JavaRDD<String> javaRDD = ctx.parallelize(data);
    
            JavaPairRDD<String, Integer> javaPairRDD = javaRDD.mapToPair(new PairFunction<String, String, Integer>() {
                public Tuple2<String, Integer> call(String key) throws Exception {
                    return new Tuple2<String, Integer>(key.split(",")[0], Integer.valueOf(key.split(",")[1]));
                }
            });
    
            final int topN = 3;
            JavaPairRDD<String, List<Integer>> combineByKeyRDD2 = javaPairRDD.combineByKey(new Function<Integer, List<Integer>>() {
                // createCombiner: start each key's accumulator with its first value
                public List<Integer> call(Integer v1) throws Exception {
                    List<Integer> items = new ArrayList<Integer>();
                    items.add(v1);
                    return items;
                }
            }, new Function2<List<Integer>, Integer, List<Integer>>() {
                // mergeValue: add the new value, then drop the smallest once the
                // list holds more than topN elements
                public List<Integer> call(List<Integer> v1, Integer v2) throws Exception {
                    v1.add(v2);
                    if (v1.size() > topN) {
                        v1.remove(Collections.min(v1));
                    }
                    return v1;
                }
            }, new Function2<List<Integer>, List<Integer>, List<Integer>>() {
                // mergeCombiners: concatenate two partial lists, then trim back down to topN
                public List<Integer> call(List<Integer> v1, List<Integer> v2) throws Exception {
                    v1.addAll(v2);
                    while (v1.size() > topN) {
                        v1.remove(Collections.min(v1));
                    }
                    return v1;
                }
            });
    
            // Flatten K:String, V:List<Integer> into individual (K, V) pairs
            // before: [(a,[210, 310, 410]), (b,[122, 212, 312]), (c,[123])]
            // after:  [(a,210), (a,310), (a,410), (b,122), (b,212), (b,312), (c,123)]
            JavaRDD<Tuple2<String, Integer>> javaTupleRDD = combineByKeyRDD2.flatMap(new FlatMapFunction<Tuple2<String, List<Integer>>, Tuple2<String, Integer>>() {
                public Iterator<Tuple2<String, Integer>> call(Tuple2<String, List<Integer>> stringListTuple2) throws Exception {
                    List<Tuple2<String, Integer>> items=new ArrayList<Tuple2<String, Integer>>();
                    for(Integer v:stringListTuple2._2){
                        items.add(new Tuple2<String, Integer>(stringListTuple2._1,v));
                    }
                    return items.iterator();
                }
            });
    
            JavaRDD<Row> rowRDD = javaTupleRDD.map(new Function<Tuple2<String, Integer>, Row>() {
                public Row call(Tuple2<String, Integer> kv) throws Exception {
                    String key = kv._1;
                    Integer num = kv._2;
    
                    return RowFactory.create(key, num);
                }
            });
    
            ArrayList<StructField> fields = new ArrayList<StructField>();
            StructField field = null;
            field = DataTypes.createStructField("key", DataTypes.StringType, true);
            fields.add(field);
            field = DataTypes.createStructField("TopN_values", DataTypes.IntegerType, true);
            fields.add(field);
    
            StructType schema = DataTypes.createStructType(fields);
    
            Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
            df.printSchema();
            df.show();
    
            spark.stop();
        }
    }

    Output:

    root
     |-- key: string (nullable = true)
     |-- TopN_values: integer (nullable = true)
    
    +---+-----------+
    |key|TopN_values|
    +---+-----------+
    |  a|        210|
    |  a|        310|
    |  a|        410|
    |  b|        122|
    |  b|        212|
    |  b|        312|
    |  c|        123|
    +---+-----------+
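
    The same bounded accumulation can be written more compactly with aggregateByKey, which takes an explicit zero value instead of a createCombiner function. A minimal Scala sketch over the same sample pairs, assuming an existing SparkContext sc (pairs, top, and topN are local names for the sketch):

    val topN = 3
    val pairs = sc.parallelize(Seq(
      ("a", 110), ("b", 122), ("c", 123), ("a", 210), ("b", 212),
      ("a", 310), ("b", 312), ("a", 410), ("b", 412)))

    // keep a descending-sorted list of at most topN values per key;
    // the first function folds in one value, the second merges two partial lists
    val top = pairs.aggregateByKey(List.empty[Int])(
      (acc, v) => (v :: acc).sorted(Ordering[Int].reverse).take(topN),
      (left, right) => (left ++ right).sorted(Ordering[Int].reverse).take(topN))

    top.collect().foreach(println) // e.g. (a,List(410, 310, 210)), (b,List(412, 312, 212)), (c,List(123))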

    Implementing Top N in Spark with combineByKeyWithClassTag

    The combineByKeyWithClassTag function keeps the largest N elements of each group in a sorted TreeSet:

    • createCombiner simply puts the first element into a TreeSet and returns it;
    • mergeValue inserts the new element and, if the set then holds more than N elements, removes the smallest;
    • mergeCombiners merges the two sets, then removes the smallest elements one at a time until only N remain.

    Note that a TreeSet deduplicates, so equal ratings inside the top N collapse into a single entry; a list-based workaround is sketched after the output below. The code:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.SparkSession
    
    import scala.collection.mutable
    
    object Main {
      val N = 3
    
      def main(args: Array[String]): Unit = {
        val spark = SparkSession
          .builder()
          .master("local[*]")
          .appName("Spark")
          .getOrCreate()
        val sc = spark.sparkContext
        val sampleDataset = List(
          ("apple.com", 3L),
          ("apple.com", 4L),
          ("apple.com", 1L),
          ("apple.com", 9L),
          ("google.com", 4L),
          ("google.com", 1L),
          ("google.com", 2L),
          ("google.com", 3L),
          ("google.com", 11L),
          ("google.com", 32L),
          ("slashdot.org", 11L),
          ("slashdot.org", 12L),
          ("slashdot.org", 13L),
          ("slashdot.org", 14L),
          ("slashdot.org", 15L),
          ("slashdot.org", 16L),
          ("slashdot.org", 17L),
          ("slashdot.org", 18L),
          ("microsoft.com", 5L),
          ("microsoft.com", 2L),
          ("microsoft.com", 6L),
          ("microsoft.com", 9L),
          ("google.com", 4L))
        val urdd: RDD[(String, Long)] = sc.parallelize(sampleDataset)
        val topNs = urdd.combineByKeyWithClassTag(
          //createCombiner
          (firstInt: Long) => {
            var uset = new mutable.TreeSet[Long]()
            uset += firstInt
          },
          // mergeValue
          (uset: mutable.TreeSet[Long], value: Long) => {
            uset += value
            while (uset.size > N) {
              uset.remove(uset.min)
            }
            uset
          },
          //mergeCombiners
          (uset1: mutable.TreeSet[Long], uset2: mutable.TreeSet[Long]) => {
            var resultSet = uset1 ++ uset2
            while (resultSet.size > N) {
              resultSet.remove(resultSet.min)
            }
            resultSet
          }
        )
        import spark.implicits._
        // flatten each (key, set) pair directly instead of round-tripping through
        // "/"-joined strings (which would break for keys containing "/")
        topNs.flatMap { case (key, values) =>
          values.toList.map(v => (key, v))
        }.toDF("key", "TopN_values").show()
      }
    }

    Reference: https://blog.csdn.net/gpwner/article/details/78455234

    Output:

    +-------------+-----------+
    |          key|TopN_values|
    +-------------+-----------+
    |   google.com|          4|
    |   google.com|         11|
    |   google.com|         32|
    |microsoft.com|          9|
    |microsoft.com|          6|
    |microsoft.com|          5|
    |    apple.com|          4|
    |    apple.com|          9|
    |    apple.com|          3|
    | slashdot.org|         16|
    | slashdot.org|         17|
    | slashdot.org|         18|
    +-------------+-----------+
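
    Because a TreeSet deduplicates, equal values inside the top N collapse into one entry. If ties must survive, a list-based combiner works; here is a minimal sketch reusing the urdd and N defined above (topNsWithTies is just a local name for the sketch):

    // same top-N per key, but duplicates are kept: hold a descending-sorted
    // List of at most N elements instead of a TreeSet
    val topNsWithTies = urdd.combineByKeyWithClassTag(
      (first: Long) => List(first),
      (acc: List[Long], v: Long) => (v :: acc).sorted(Ordering[Long].reverse).take(N),
      (left: List[Long], right: List[Long]) => (left ++ right).sorted(Ordering[Long].reverse).take(N))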
  • Original post: https://www.cnblogs.com/yy3b2007com/p/9363474.html