  • Two Ways to Compute Top N per Group in Spark

    Computing Top N per group in Spark

    In big data processing, taking the Top N records within each group is a very common operation.

    The example below shows how to perform a grouped Top N computation in Spark.

    1. Top N per group with the RDD API

    from pyspark import SparkContext
    sc = SparkContext()
    

    Prepare the data and convert it to an RDD.

    data_list = [
     (0, "cat26", 130.9), (0, "cat13", 122.1), (0, "cat95", 119.6), (0, "cat105", 11.3),
     (1, "cat67", 128.5), (1, "cat4", 126.8), (1, "cat13", 112.6), (1, "cat23", 15.3),
     (2, "cat56", 139.6), (2, "cat40", 129.7), (2, "cat187", 127.9), (2, "cat68", 19.8),
     (3, "cat8", 135.6)
    ]
    
    data = sc.parallelize(data_list)
    data.collect()
    
    [(0, 'cat26', 130.9),
     (0, 'cat13', 122.1),
     (0, 'cat95', 119.6),
     (0, 'cat105', 11.3),
     (1, 'cat67', 128.5),
     (1, 'cat4', 126.8),
     (1, 'cat13', 112.6),
     (1, 'cat23', 15.3),
     (2, 'cat56', 139.6),
     (2, 'cat40', 129.7),
     (2, 'cat187', 127.9),
     (2, 'cat68', 19.8),
     (3, 'cat8', 135.6)]
    

    Group the data with groupBy. After grouping, each element is a (key, grouped_rows) pair.

    d1 = data.groupBy(lambda x:x[0])
    temp = d1.collect()
    print(list(temp[0][1]))
    print(temp)
    
    [(0, 'cat26', 130.9), (0, 'cat13', 122.1), (0, 'cat95', 119.6), (0, 'cat105', 11.3)]
    [(0, <pyspark.resultiterable.ResultIterable object at 0x0000000007D2C710>), (1, <pyspark.resultiterable.ResultIterable object at 0x0000000007D2C780>), (2, <pyspark.resultiterable.ResultIterable object at 0x0000000007D2C898>), (3, <pyspark.resultiterable.ResultIterable object at 0x0000000007D2C9B0>)]
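
    The grouped values print as ResultIterable objects. To inspect each group's contents, one quick (optional) check is to materialize the values into lists:

    d1.mapValues(list).collect()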
    

    Use mapValues to sort the rows within each group, then slice off as many rows as you need. Here we keep 3 rows per group.

    Note that sorted() sorts in ascending order, so this keeps the three rows with the smallest TotalValue; to keep the largest instead, pass reverse=True (or see the sketch after the output below).

    d2 = d1.mapValues(lambda x: sorted(x, key=lambda y:y[2])[:3])
    d2.collect()
    
    [(0, [(0, 'cat105', 11.3), (0, 'cat95', 119.6), (0, 'cat13', 122.1)]),
     (1, [(1, 'cat23', 15.3), (1, 'cat13', 112.6), (1, 'cat4', 126.8)]),
     (2, [(2, 'cat68', 19.8), (2, 'cat187', 127.9), (2, 'cat40', 129.7)]),
     (3, [(3, 'cat8', 135.6)])]
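
    Because sorted() runs in ascending order, the result above contains the three smallest TotalValue rows per group. If you want the three largest (matching the DataFrame example later), one option is heapq.nlargest, which also avoids fully sorting large groups. A minimal sketch, reusing d1 from above:

    import heapq

    # keep the 3 rows with the largest TotalValue in each group
    d2_desc = d1.mapValues(lambda rows: heapq.nlargest(3, rows, key=lambda r: r[2]))
    d2_desc.collect()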
    

    Use flatMap to flatten the grouped results back into a single flat list of rows.

    d3 = d2.flatMap(lambda x:[i for i in x[1]])
    d3.collect()
    
    [(0, 'cat105', 11.3),
     (0, 'cat95', 119.6),
     (0, 'cat13', 122.1),
     (1, 'cat23', 15.3),
     (1, 'cat13', 112.6),
     (1, 'cat4', 126.8),
     (2, 'cat68', 19.8),
     (2, 'cat187', 127.9),
     (2, 'cat40', 129.7),
     (3, 'cat8', 135.6)]
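
    Because flatMap already flattens whatever iterable the function returns, the list comprehension above is not strictly needed; an equivalent, slightly simpler form is:

    d3 = d2.flatMap(lambda x: x[1])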
    

    Complete code

    from pyspark import SparkContext
    # sc = SparkContext()  # uncomment if a SparkContext is not already running
    
    topN = 3
    data_list = [
     (0, "cat26", 130.9), (0, "cat13", 122.1), (0, "cat95", 119.6), (0, "cat105", 11.3),
     (1, "cat67", 128.5), (1, "cat4", 126.8), (1, "cat13", 112.6), (1, "cat23", 15.3),
     (2, "cat56", 139.6), (2, "cat40", 129.7), (2, "cat187", 127.9), (2, "cat68", 19.8),
     (3, "cat8", 135.6)
    ]
    
    data = sc.parallelize(data_list)
    d1 = data.groupBy(lambda x: x[0])                                   # group rows by the first field
    d2 = d1.mapValues(lambda x: sorted(x, key=lambda y: y[2])[:topN])   # sort each group and keep topN rows
    d3 = d2.flatMap(lambda x: [i for i in x[1]])                        # flatten back into a single list of rows
    d3.collect()
    
    [(0, 'cat105', 11.3),
     (0, 'cat95', 119.6),
     (0, 'cat13', 122.1),
     (1, 'cat23', 15.3),
     (1, 'cat13', 112.6),
     (1, 'cat4', 126.8),
     (2, 'cat68', 19.8),
     (2, 'cat187', 127.9),
     (2, 'cat40', 129.7),
     (3, 'cat8', 135.6)]
    

    2. Top N per group with the DataFrame API

    For data in a DataFrame, the simplest way to take the Top N per group is to use a window function.

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as func
    from pyspark.sql import Window
    
    spark = SparkSession.builder.getOrCreate()
    
    data_list = [
     (0, "cat26", 130.9), (0, "cat13", 122.1), (0, "cat95", 119.6), (0, "cat105", 11.3),
     (1, "cat67", 128.5), (1, "cat4", 126.8), (1, "cat13", 112.6), (1, "cat23", 15.3),
     (2, "cat56", 139.6), (2, "cat40", 129.7), (2, "cat187", 127.9), (2, "cat68", 19.8),
     (3, "cat8", 135.6)
    ]
    
    
    Create a DataFrame from the data and name the columns.
    
    df = spark.createDataFrame(data_list, ["Hour", "Category", "TotalValue"])
    df.show()
    
    +----+--------+----------+
    |Hour|Category|TotalValue|
    +----+--------+----------+
    |   0|   cat26|     130.9|
    |   0|   cat13|     122.1|
    |   0|   cat95|     119.6|
    |   0|  cat105|      11.3|
    |   1|   cat67|     128.5|
    |   1|    cat4|     126.8|
    |   1|   cat13|     112.6|
    |   1|   cat23|      15.3|
    |   2|   cat56|     139.6|
    |   2|   cat40|     129.7|
    |   2|  cat187|     127.9|
    |   2|   cat68|      19.8|
    |   3|    cat8|     135.6|
    +----+--------+----------+
    
    1. Build a window specification: partitionBy takes the grouping key.

    2. orderBy takes the sort key; here desc() sorts in descending order.

    3. withColumn(colName, col) adds a column to df holding the row number generated over the window.

    4. where keeps the rows whose rn value is less than or equal to 3, i.e. the top 3 rows per group.

    w = Window.partitionBy(df.Hour).orderBy(df.TotalValue.desc())
    top3 = df.withColumn('rn', func.row_number().over(w)).where('rn <=3')
    top3.show()
    
    +----+--------+----------+---+
    |Hour|Category|TotalValue| rn|
    +----+--------+----------+---+
    |   0|   cat26|     130.9|  1|
    |   0|   cat13|     122.1|  2|
    |   0|   cat95|     119.6|  3|
    |   1|   cat67|     128.5|  1|
    |   1|    cat4|     126.8|  2|
    |   1|   cat13|     112.6|  3|
    |   3|    cat8|     135.6|  1|
    |   2|   cat56|     139.6|  1|
    |   2|   cat40|     129.7|  2|
    |   2|  cat187|     127.9|  3|
    +----+--------+----------+---+
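
    row_number() assigns unique, consecutive numbers even when TotalValue values tie, so exactly three rows per group survive the filter. If tied rows should instead share a position, rank() or dense_rank() can be swapped in; a small variation on the same window:

    # rank() leaves gaps after ties, dense_rank() does not
    top3_rank = df.withColumn('rk', func.rank().over(w)).where('rk <= 3')
    top3_rank.show()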
    
    Code summary
    
    from pyspark.sql import SparkSession
    from pyspark.sql import functions as func
    from pyspark.sql import Window
    
    spark = SparkSession.builder.getOrCreate()
    
    data_list = [
     (0, "cat26", 130.9), (0, "cat13", 122.1), (0, "cat95", 119.6), (0, "cat105", 11.3),
     (1, "cat67", 128.5), (1, "cat4", 126.8), (1, "cat13", 112.6), (1, "cat23", 15.3),
     (2, "cat56", 139.6), (2, "cat40", 129.7), (2, "cat187", 127.9), (2, "cat68", 19.8),
     (3, "cat8", 135.6)
    ]
    df = spark.createDataFrame(data_list, ["Hour", "Category", "TotalValue"])
    
    w = Window.partitionBy(df.Hour).orderBy(df.TotalValue.desc())
    top3 = df.withColumn('rn', func.row_number().over(w)).where('rn <=3')
    
    top3.show()
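
    If the helper rn column is not wanted in the final result, it can be dropped after filtering:

    top3.drop('rn').show()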
    
  • Original article: https://www.cnblogs.com/StitchSun/p/13255096.html