  • Spark Custom Partitioning (Partitioner)

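    Spark provides two partitioning strategies, HashPartitioner and RangePartitioner, and they suit our scenarios in many cases. Sometimes, however, the built-in strategies cannot meet our needs, and that is when we can define our own partitioning strategy.
    For this, Spark provides an extension point: we only need to extend the abstract class Partitioner and implement its methods.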

    The Partitioner class is defined as follows:

    /**
     * An object that defines how the elements in a key-value pair RDD are partitioned by key.
     * Maps each key to a partition ID, from 0 to `numPartitions - 1`.
     */
    abstract class Partitioner extends Serializable {
      // Returns the number of partitions to create.
      def numPartitions: Int
      // Computes the partition ID for the given key, in the range 0 to numPartitions - 1.
      def getPartition(key: Any): Int
    }
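    To make the contract concrete without a Spark dependency, here is a minimal plain-Java sketch. `KeyPartitioner`, `PartitionerContract` and `bucketize` are hypothetical names, not Spark API: the interface only mirrors the two abstract members above, and the helper assigns each key to a bucket the way a shuffle would, failing fast if `getPartition` returns an ID outside 0 to `numPartitions - 1`.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical mirror of Spark's Partitioner contract (NOT the Spark API).
interface KeyPartitioner {
    int numPartitions();            // how many partitions to create
    int getPartition(Object key);   // partition ID in [0, numPartitions - 1]
}

final class PartitionerContract {
    // Groups keys into one bucket per partition and rejects out-of-range IDs.
    static List<List<Object>> bucketize(KeyPartitioner p, List<?> keys) {
        List<List<Object>> buckets = new ArrayList<>();
        for (int i = 0; i < p.numPartitions(); i++) {
            buckets.add(new ArrayList<>());
        }
        for (Object key : keys) {
            int id = p.getPartition(key);
            if (id < 0 || id >= p.numPartitions()) {
                throw new IllegalStateException("partition " + id + " out of range for key " + key);
            }
            buckets.get(id).add(key);
        }
        return buckets;
    }

    public static void main(String[] args) {
        // Toy partitioner: even-length strings go to partition 0, odd-length ones to 1.
        KeyPartitioner byLengthParity = new KeyPartitioner() {
            public int numPartitions() { return 2; }
            public int getPartition(Object key) { return key.toString().length() % 2; }
        };
        System.out.println(bucketize(byLengthParity, List.of("ab", "abc", "abcd"))); // [[ab, abcd], [abc]]
    }
}
```

    In a real job you would instead extend `org.apache.spark.Partitioner` and let Spark perform the bucketing during the shuffle.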

    Spark's default implementation is HashPartitioner. Here is how it is implemented:

    /**
     * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using
     * Java's `Object.hashCode`.
     *
     * Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
     * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
     * produce an unexpected or incorrect result.
     */
    class HashPartitioner(partitions: Int) extends Partitioner {
      require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")
    
      def numPartitions: Int = partitions
    
      def getPartition(key: Any): Int = key match {
        case null => 0
        case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
      }
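      // Standard Java-style equality check. Spark needs it because it compares the partitioners of two RDDs to decide whether they are partitioned the same way.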
      override def equals(other: Any): Boolean = other match {
        case h: HashPartitioner =>
          h.numPartitions == numPartitions
        case _ =>
          false
      }
    
      override def hashCode: Int = numPartitions
    }
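    The core of `getPartition` can be restated in plain Java to see the behaviour the scaladoc warns about. This is an illustrative sketch only: `HashPartitionDemo` and `hashPartition` are hypothetical names, and a real job would simply use `org.apache.spark.HashPartitioner`. It shows that equal strings always share a partition, while two arrays with equal contents are not equal and use identity-based hash codes.

```java
import java.util.Arrays;

// Plain-Java restatement of HashPartitioner.getPartition (illustration only).
final class HashPartitionDemo {
    static int nonNegativeMod(int x, int mod) {
        int rawMod = x % mod;
        return rawMod + (rawMod < 0 ? mod : 0);
    }

    // null keys go to partition 0; everything else is bucketed by hashCode()
    static int hashPartition(Object key, int numPartitions) {
        return key == null ? 0 : nonNegativeMod(key.hashCode(), numPartitions);
    }

    public static void main(String[] args) {
        // Equal strings have equal hash codes, so they always land together.
        System.out.println(hashPartition("jack1", 4) == hashPartition(new String("jack1"), 4)); // true

        // Arrays inherit identity-based equals/hashCode from Object, so two arrays
        // with the same contents are not equal and (almost always) hash differently.
        int[] a = {1, 2, 3};
        int[] b = {1, 2, 3};
        System.out.println(a.equals(b));                              // false
        // Arrays.hashCode hashes the contents instead.
        System.out.println(Arrays.hashCode(a) == Arrays.hashCode(b)); // true
    }
}
```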

    The nonNegativeMod method (in org.apache.spark.util.Utils):

     /* Calculates 'x' modulo 'mod', takes to consideration sign of x,
      * i.e. if 'x' is negative, than 'x' % 'mod' is negative too
      * so function return (x % mod) + mod in that case.
      */
      def nonNegativeMod(x: Int, mod: Int): Int = {
        val rawMod = x % mod
        rawMod + (if (rawMod < 0) mod else 0)
      }
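    A small worked example of why this helper exists: Java's `%` keeps the sign of the dividend, so a negative `hashCode` would give a negative, invalid partition ID. The class name `NonNegativeModDemo` is hypothetical. For a positive `mod` the result matches the JDK's `Math.floorMod`, and the sketch also shows why the tempting shortcut `Math.abs(h) % n` is unsafe.

```java
// Why nonNegativeMod is needed: Java's % keeps the sign of the dividend.
final class NonNegativeModDemo {
    // Same logic as Spark's Utils.nonNegativeMod shown above.
    static int nonNegativeMod(int x, int mod) {
        int rawMod = x % mod;
        return rawMod + (rawMod < 0 ? mod : 0);
    }

    public static void main(String[] args) {
        System.out.println(-7 % 3);                // -1: not a valid partition ID
        System.out.println(nonNegativeMod(-7, 3)); // 2
        System.out.println(Math.floorMod(-7, 3));  // 2: same result for a positive mod

        // Math.abs(h) % n breaks for Integer.MIN_VALUE, because
        // Math.abs(Integer.MIN_VALUE) overflows and stays negative.
        System.out.println(Math.abs(Integer.MIN_VALUE) % 3);      // -2
        System.out.println(nonNegativeMod(Integer.MIN_VALUE, 3)); // 1
    }
}
```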

    An example: given the following RDD,

            // Goal: put the "jack" elements and the "world" elements into separate partitions
            JavaRDD<String> javaRDD = jsc.parallelize(Arrays.asList("jack1", "jack2", "jack3",
                    "world1", "world2", "world3"));

    The custom partitioner:

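    import org.apache.spark.Partitioner;
    
    /**
     * Custom Partitioner: keys that share the same prefix (the key without its
     * last character, e.g. "jack" for "jack1") go to the same partition.
     */
    public class MyPartitioner extends Partitioner {
    
        private final int numPartitions;
    
        public MyPartitioner(int numPartitions) {
            this.numPartitions = numPartitions;
        }

        @Override
        public int numPartitions() {
            return numPartitions;
        }
    
        @Override
        public int getPartition(Object key) {
            if (key == null) {
                return 0;
            }
            String str = key.toString();
            // Drop the trailing digit so "jack1", "jack2", ... share one prefix;
            // guard against empty keys, where substring(0, -1) would throw.
            String prefix = str.isEmpty() ? str : str.substring(0, str.length() - 1);
            return nonNegativeMod(prefix.hashCode(), numPartitions);
        }
    
        @Override
        public boolean equals(Object obj) {
            if (obj instanceof MyPartitioner) {
                return ((MyPartitioner) obj).numPartitions == numPartitions;
            }
            return false;
        }

        // Keep hashCode consistent with equals, as HashPartitioner does.
        @Override
        public int hashCode() {
            return numPartitions;
        }
    
        // Java port of Utils.nonNegativeMod(key.hashCode, numPartitions)
        private int nonNegativeMod(int hashCode, int numPartitions) {
            int rawMod = hashCode % numPartitions;
            if (rawMod < 0) {
                rawMod = rawMod + numPartitions;
            }
            return rawMod;
        }
    
    }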
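    MyPartitioner overrides equals, and Java's contract requires equal objects to have equal hash codes, so hashCode should be overridden alongside it, as HashPartitioner does. Spark relies on this equality: in the Spark source, for example, partitionBy returns the RDD unchanged when its current partitioner equals the requested one. The sketch below uses a hypothetical, Spark-free stand-in class `PrefixPartitionerKey` that keeps only the `numPartitions` field, to show the pair of methods working together in a hash-based collection.

```java
import java.util.HashSet;
import java.util.Set;

// Hypothetical stand-in for MyPartitioner without the Spark dependency:
// two instances describe the same partitioning when numPartitions matches.
final class PrefixPartitionerKey {
    private final int numPartitions;

    PrefixPartitionerKey(int numPartitions) { this.numPartitions = numPartitions; }

    @Override
    public boolean equals(Object obj) {
        return obj instanceof PrefixPartitionerKey
                && ((PrefixPartitionerKey) obj).numPartitions == numPartitions;
    }

    // Must agree with equals(): equal objects need equal hash codes.
    @Override
    public int hashCode() { return numPartitions; }

    public static void main(String[] args) {
        PrefixPartitionerKey p1 = new PrefixPartitionerKey(2);
        PrefixPartitionerKey p2 = new PrefixPartitionerKey(2);
        System.out.println(p1.equals(p2)); // true
        Set<PrefixPartitionerKey> set = new HashSet<>();
        set.add(p1);
        set.add(p2);
        System.out.println(set.size());    // 1: treated as one partitioning
    }
}
```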

    We then pass the custom partitioner to partitionBy(). A test example:

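    import java.util.ArrayList;
    import java.util.Arrays;
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import scala.Tuple2;

            // Local Spark context with two worker threads for the test
            SparkConf conf = new SparkConf().setAppName("MyPartitionerTest").setMaster("local[2]");
            JavaSparkContext jsc = new JavaSparkContext(conf);
            // Goal: put the "jack" elements and the "world" elements into separate partitions
            JavaRDD<String> javaRDD = jsc.parallelize(Arrays.asList("jack1", "jack2", "jack3",
                    "world1", "world2", "world3"));
            // A custom partitioner can only be applied to a pair RDD, so map each element to (element, 1)
            JavaPairRDD<String, Integer> pairRDD = javaRDD.mapToPair(s -> new Tuple2<>(s, 1));
            JavaPairRDD<String, Integer> pairRDD1 = pairRDD.partitionBy(new MyPartitioner(2));
            System.out.println("Number of partitions after partitionBy: " + pairRDD1.getNumPartitions());
    
            pairRDD1.mapPartitionsWithIndex((index, it) -> {
                ArrayList<String> result = new ArrayList<>();
                while (it.hasNext()) {
                    result.add(index + "_" + it.next());
                }
                return result.iterator();
            }, true).foreach(s -> System.out.println(s));
            jsc.stop();

    Output (in local mode the lines are printed to the console; the order in which the two partitions are printed may vary between runs):

    Number of partitions after partitionBy: 2
    0_(world1,1)
    0_(world2,1)
    0_(world3,1)
    1_(jack1,1)
    1_(jack2,1)
    1_(jack3,1)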
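    The assignment in the output above can be replayed in plain Java, without starting Spark, by applying the same prefix-hash rule as MyPartitioner.getPartition. `PrefixPartitionReplay`, `partitionFor` and `replay` are hypothetical names for this sketch. It also shows a limitation of any hash-based scheme: "jack" and "world" happen to land in different partitions when there are 2, but with 7 partitions both prefixes hash to partition 2, so separation is not guaranteed.

```java
import java.util.ArrayList;
import java.util.List;

// Plain-Java replay of MyPartitioner.getPartition on the sample keys.
final class PrefixPartitionReplay {
    static int partitionFor(String key, int numPartitions) {
        String prefix = key.isEmpty() ? key : key.substring(0, key.length() - 1); // drop trailing digit
        return Math.floorMod(prefix.hashCode(), numPartitions); // equals nonNegativeMod for n > 0
    }

    // Produces "partition_(key,1)" lines, partition by partition, like the Spark test.
    static List<String> replay(List<String> keys, int numPartitions) {
        List<String> lines = new ArrayList<>();
        for (int p = 0; p < numPartitions; p++) {
            for (String k : keys) {
                if (partitionFor(k, numPartitions) == p) {
                    lines.add(p + "_(" + k + ",1)");
                }
            }
        }
        return lines;
    }

    public static void main(String[] args) {
        List<String> keys = List.of("jack1", "jack2", "jack3", "world1", "world2", "world3");
        replay(keys, 2).forEach(System.out::println);
        // With 7 partitions the two prefixes collide.
        System.out.println(partitionFor("jack1", 7) + " " + partitionFor("world1", 7)); // 2 2
    }
}
```

    If two groups must never share a partition, an explicit mapping from prefix to partition ID avoids relying on hash values.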

    Reference: https://my.oschina.net/u/939952/blog/1863372

    Reference: https://www.iteblog.com/archives/1368.html

  • Original article: https://www.cnblogs.com/zz-ksw/p/12454324.html