Spark aggregate source code
/**
 * Aggregate the elements of each partition, and then the results for all the partitions, using
 * given combine functions and a neutral "zero value". This function can return a different result
 * type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U
 * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are
 * allowed to modify and return their first argument instead of creating a new U to avoid memory
 * allocation.
 */
def aggregate[U](zeroValue: U)(seqOp: JFunction2[U, T, U],
    combOp: JFunction2[U, U, U]): U =
  rdd.aggregate(zeroValue)(seqOp, combOp)(fakeClassTag[U])
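aggregate combines the elements of an RDD in two steps. First, seqOp folds the elements of type T inside each partition into a single value of type U.
Then combOp merges the per-partition U values into one final U. Note that both seqOp and combOp use zeroValue, which has type U: it is the starting value for seqOp in every partition and also the starting value for combOp in the final merge.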
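To make the two steps concrete, here is a minimal plain-Java sketch that mimics the behaviour described above without Spark. The class AggregateSketch, the helper methods aggregate and sumAndCount, and the use of a List of Lists to stand in for partitions are all assumptions made for illustration; this is not Spark's implementation. A Supplier is used for the zero value so that each partition gets its own fresh copy, since seqOp and combOp are allowed to modify their first argument. Here T is Integer and U is a long[] holding {sum, count}, so the result type differs from the element type.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Supplier;

public class AggregateSketch {

    // Hypothetical helper: each inner list plays the role of one partition.
    static <T, U> U aggregate(List<List<T>> partitions,
                              Supplier<U> zeroValue,
                              BiFunction<U, T, U> seqOp,
                              BiFunction<U, U, U> combOp) {
        U result = zeroValue.get();              // zero value for the final merge
        for (List<T> partition : partitions) {
            U acc = zeroValue.get();             // fresh zero value for this partition
            for (T element : partition) {
                acc = seqOp.apply(acc, element); // merge a T into a U
            }
            result = combOp.apply(result, acc);  // merge two U's
        }
        return result;
    }

    // U = long[]{sum, count}; both functions mutate and return their first argument.
    static long[] sumAndCount(List<List<Integer>> partitions) {
        return AggregateSketch.<Integer, long[]>aggregate(
                partitions,
                () -> new long[] {0L, 0L},
                (acc, x) -> { acc[0] += x; acc[1] += 1; return acc; },
                (a, b) -> { a[0] += b[0]; a[1] += b[1]; return a; });
    }

    public static void main(String[] args) {
        List<List<Integer>> partitions = Arrays.asList(
                Arrays.asList(2, 3, 2),
                Arrays.asList(5, 2, 6));
        long[] r = sumAndCount(partitions);
        System.out.println("sum=" + r[0] + ", count=" + r[1]); // sum=20, count=6
    }
}
```

Because {0, 0} is a neutral value for addition, this particular aggregation gives the same answer however the elements are split across partitions.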
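Sample code:
Note:
The single-partition and multi-partition cases behave differently, because zeroValue is applied once in every partition and once more in the final merge.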
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.VoidFunction;

// Assumes sc is an existing JavaSparkContext.
List<Integer> list = new ArrayList<>();
list.add(2);
list.add(3);
list.add(2);
list.add(5);
list.add(2);
list.add(6);

// Single-partition case
JavaRDD<Integer> rdd1 = sc.parallelize(list, 1);
System.out.println("NumPartitions :" + rdd1.getNumPartitions());
int result1 = rdd1.aggregate(1, new Function2<Integer, Integer, Integer>() {
    @Override
    public Integer call(Integer v1, Integer v2) throws Exception {
        // seqOp: zeroValue * 2, then * 3, * 2, * 5, * 2, * 6, which gives 1*2*3*2*5*2*6 = 720
        return v1 * v2;
    }
}, new Function2<Integer, Integer, Integer>() {
    @Override
    public Integer call(Integer v1, Integer v2) throws Exception {
        // combOp: zeroValue + the partition result, i.e. 1 + 720 = 721
        return v1 + v2;
    }
});
System.out.println("result1: " + result1);

// Multi-partition case
JavaRDD<Integer> rdd2 = sc.parallelize(list, 2);
System.out.println("NumPartitions :" + rdd2.getNumPartitions());
// Tag every element with the index of the partition that holds it.
JavaRDD<String> mapPartitionsWithIndex = rdd2.mapPartitionsWithIndex(
        new Function2<Integer, Iterator<Integer>, Iterator<String>>() {
    @Override
    public Iterator<String> call(Integer part_id, Iterator<Integer> iterator) throws Exception {
        List<String> tagged = new ArrayList<>();
        while (iterator.hasNext()) {
            tagged.add("partition" + part_id + ":" + iterator.next());
        }
        return tagged.iterator();
    }
}, true);
mapPartitionsWithIndex.foreachPartition((VoidFunction<Iterator<String>>) iterator -> {
    while (iterator.hasNext()) {
        System.out.println(iterator.next());
    }
});
// Output (the order across partitions may vary):
// partition0:2
// partition0:3
// partition0:2
// partition1:5
// partition1:2
// partition1:6

int result2 = rdd2.aggregate(2, new Function2<Integer, Integer, Integer>() {
    @Override
    public Integer call(Integer v1, Integer v2) throws Exception {
        // This time zeroValue is 2.
        // partition0 holds 2, 3, 2: the result is 2*2*3*2 = 24, where the leading 2 is zeroValue
        // partition1 holds 5, 2, 6: the result is 2*5*2*6 = 120, where the leading 2 is zeroValue
        return v1 * v2;
    }
}, new Function2<Integer, Integer, Integer>() {
    @Override
    public Integer call(Integer v1, Integer v2) throws Exception {
        // combOp: 2 + 24 + 120 = 146, where the leading 2 is zeroValue
        return v1 + v2;
    }
});
System.out.println("result2: " + result2);
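The arithmetic in the comments of the sample can be checked without a cluster. The following plain-Java sketch is only an illustration: the class PartitionArithmetic and the helper multiplyThenAdd are hypothetical names, and the nested lists are an assumed stand-in for how the elements are split into partitions. It hard-codes the sample's two functions (multiply inside a partition, add across partitions) and applies zeroValue in both places.

```java
import java.util.Arrays;
import java.util.List;

public class PartitionArithmetic {

    // Hypothetical stand-in for aggregate(zeroValue, (a, b) -> a * b, (a, b) -> a + b).
    static int multiplyThenAdd(List<List<Integer>> partitions, int zeroValue) {
        int total = zeroValue;                 // combOp starts from zeroValue
        for (List<Integer> partition : partitions) {
            int product = zeroValue;           // seqOp starts from zeroValue in each partition
            for (int element : partition) {
                product = product * element;   // seqOp: v1 * v2
            }
            total = total + product;           // combOp: v1 + v2
        }
        return total;
    }

    public static void main(String[] args) {
        List<List<Integer>> one = Arrays.asList(Arrays.asList(2, 3, 2, 5, 2, 6));
        List<List<Integer>> two = Arrays.asList(Arrays.asList(2, 3, 2), Arrays.asList(5, 2, 6));
        System.out.println(multiplyThenAdd(one, 1)); // 1 + 720 = 721
        System.out.println(multiplyThenAdd(two, 2)); // 2 + 24 + 120 = 146
        System.out.println(multiplyThenAdd(one, 2)); // 2 + 1440 = 1442
    }
}
```

The last two calls use the same zeroValue on the same data and still disagree (146 versus 1442). That is the partition dependence the note before the sample warns about: a zeroValue that is not neutral for both functions is counted once per partition and once more in the final merge.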