  • Spark partitionBy

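    partitionBy redistributes a key-value RDD according to the Partitioner you pass in. HashPartitioner is Spark's default partitioner (repartition also shuffles through a HashPartitioner, but over generated position keys rather than the data's own keys, so it only balances record counts). You can also design your own partitioning scheme, e.g. append a random number to keys that have very many records so they are spread randomly over more partitions, which handles data skew more thoroughly.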
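    A minimal sketch of the "add a random number to large keys" idea, usually called key salting, applied as a two-stage aggregation with reduceByKey. It assumes a local SparkSession; the object SaltedAggregation, the helpers salt/unsalt, the sample data and saltBuckets = 8 are hypothetical choices for illustration, not part of the original article.

```scala
import org.apache.spark.sql.SparkSession
import scala.util.Random

// Two-stage aggregation with salted keys (illustrative sketch).
object SaltedAggregation {
  // Append a random suffix "_0" .. "_<buckets-1>" so one hot key becomes several keys
  def salt(key: String, buckets: Int): String = s"${key}_${Random.nextInt(buckets)}"

  // Strip the suffix added by salt() to recover the original key
  def unsalt(saltedKey: String): String = saltedKey.substring(0, saltedKey.lastIndexOf('_'))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("SaltedAggregation").getOrCreate()
    val sc = spark.sparkContext

    // Skewed input (made up): key "a" has far more records than "b" and "c"
    val data = sc.parallelize(Seq.fill(1000)(("a", 1)) ++ Seq(("b", 1), ("c", 1)), 4)
    val saltBuckets = 8 // assumption: number of sub-keys per original key

    // Stage 1: aggregate per salted key, so "a" is spread over up to 8 reduce groups
    val partial = data
      .map { case (k, v) => (salt(k, saltBuckets), v) }
      .reduceByKey(_ + _)

    // Stage 2: remove the salt and combine the partial sums of each original key
    val totals = partial
      .map { case (k, v) => (unsalt(k), v) }
      .reduceByKey(_ + _)

    totals.collect().sortBy(_._1).foreach(println)
    spark.stop()
  }
}
```

    Because stage 2 adds every salted partial sum back together, the totals do not depend on which random suffix each record received; with this data it should print (a,1000), (b,1) and (c,1).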

    /**
     * An object that defines how the elements in a key-value pair RDD are partitioned by key.
     * Maps each key to a partition ID, from 0 to `numPartitions - 1`.
     */
    abstract class Partitioner extends Serializable {
      def numPartitions: Int
      def getPartition(key: Any): Int
    }
    import org.apache.spark.HashPartitioner
    import org.apache.spark.sql.SparkSession
    
    // Print the elements held in each partition of the RDD
    object PartitionBy_Test {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local")
          .appName(this.getClass.getSimpleName)
          .getOrCreate()

        val rdd = spark.sparkContext.parallelize(
          Array(("a", 1), ("a", 2), ("b", 1), ("b", 3), ("c", 1), ("e", 1)), 2)

        // Repartition by key; records with the same key end up in the same partition
        val partitioned = rdd.partitionBy(new HashPartitioner(2))

        // For each partition, emit ("part_<index>", elements of that partition)
        val result = partitioned.mapPartitionsWithIndex { (partIdx, iter) =>
          Iterator(("part_" + partIdx, iter.toList))
        }.collect()

        result.foreach { case (name, elems) => println(name + ":" + elems) }
        spark.stop()
      }
    }
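    The per-partition loop above can be shortened with glom(), which turns each partition into an array. This sketch (GlomInspect and its helper keysPerPartition are hypothetical names; a local SparkSession is assumed) uses the same data and prints which keys each partition holds, together with the RDD's partitioner field, before and after partitionBy.

```scala
import org.apache.spark.HashPartitioner
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

object GlomInspect {
  // glom() turns each partition into an array; return the distinct keys per partition
  def keysPerPartition(rdd: RDD[(String, Int)]): Seq[Set[String]] =
    rdd.glom().collect().toSeq.map(_.map(_._1).toSet)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("GlomInspect").getOrCreate()
    val rdd = spark.sparkContext.parallelize(
      Seq(("a", 1), ("a", 2), ("b", 1), ("b", 3), ("c", 1), ("e", 1)), 2)

    // parallelize splits the sequence by position and attaches no partitioner
    println(s"before: partitioner=${rdd.partitioner}, keys=${keysPerPartition(rdd)}")

    // partitionBy shuffles by key, so all records of one key share a partition
    val byKey = rdd.partitionBy(new HashPartitioner(2))
    println(s"after:  partitioner=${byKey.partitioner}, keys=${keysPerPartition(byKey)}")

    spark.stop()
  }
}
```

    With this data, key "b" sits in both partitions before the shuffle and in exactly one afterwards, and partitioner changes from None to Some(HashPartitioner).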

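    The partitioning method here is pluggable; the default is HashPartitioner.
    Note: if the RDD is used more than once, or is joined with other RDDs, persist it right after partitioning. Otherwise the partitionBy shuffle is re-run every time the RDD is recomputed.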
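    A sketch of the partition-then-persist pattern, assuming a local SparkSession; the object PartitionThenPersist, its buildUsers helper and the user/event data are made up for illustration. The reused pair RDD is shuffled once with partitionBy and kept in memory, so later joins only need to shuffle the other side.

```scala
import org.apache.spark.{HashPartitioner, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

object PartitionThenPersist {
  // Reused side of the join: shuffle it once by key, then keep the result in memory
  def buildUsers(sc: SparkContext): RDD[(Int, String)] =
    sc.parallelize(Seq(1 -> "alice", 2 -> "bob", 3 -> "carol"))
      .partitionBy(new HashPartitioner(4))
      .persist(StorageLevel.MEMORY_ONLY)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("PartitionThenPersist").getOrCreate()
    val sc = spark.sparkContext
    val users = buildUsers(sc)

    // Side that changes between jobs (hypothetical event data)
    val events = sc.parallelize(Seq(1 -> "login", 3 -> "click", 1 -> "logout"))

    // join() can reuse users' HashPartitioner, so only events needs to be shuffled.
    // Without persist(), users' partitionBy shuffle would be redone for each action.
    val joined = users.join(events)
    println(joined.partitioner)
    joined.collect().sortBy(_._1).foreach(println)

    // mapValues keeps the partitioner (keys cannot change); map drops it
    println(users.mapValues(_.toUpperCase).partitioner.isDefined) // true
    println(users.map(identity).partitioner.isDefined)            // false

    spark.stop()
  }
}
```

    Prefer mapValues/flatMapValues over map on a partitioned RDD when only the values change, so the partitioning (and the shuffle it cost) is not thrown away.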

    /**
     * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using
     * Java's `Object.hashCode`.
     *
     * Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
     * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
     * produce an unexpected or incorrect result.
     */
    class HashPartitioner(partitions: Int) extends Partitioner {
      require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")
    
      def numPartitions: Int = partitions
    
      def getPartition(key: Any): Int = key match {
        case null => 0
        case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
      }
    
      override def equals(other: Any): Boolean = other match {
        case h: HashPartitioner =>
          h.numPartitions == numPartitions
        case _ =>
          false
      }
    
      override def hashCode: Int = numPartitions
    }
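    To make getPartition concrete, this sketch repeats its arithmetic in plain Scala without any Spark dependency. Spark's Utils object is private to Spark, so nonNegativeMod is re-implemented here with the same arithmetic; HashPartitionMath and partitionOf are hypothetical names. It also demonstrates the array pitfall described in the comment above.

```scala
object HashPartitionMath {
  // Same arithmetic as Spark's (private) Utils.nonNegativeMod:
  // Java/Scala % keeps the sign of the dividend, so negative hashCodes need a correction
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val rawMod = x % mod
    rawMod + (if (rawMod < 0) mod else 0)
  }

  // Partition that HashPartitioner(numPartitions).getPartition would return for key
  def partitionOf(key: Any, numPartitions: Int): Int =
    if (key == null) 0 else nonNegativeMod(key.hashCode, numPartitions)

  def main(args: Array[String]): Unit = {
    // Keys from the PartitionBy_Test example, with 2 partitions
    for (k <- Seq("a", "b", "c", "e"))
      println(s"$k (hashCode ${k.hashCode}) -> part_${partitionOf(k, 2)}")

    println(-7 % 3)                // -1: raw remainder is negative
    println(nonNegativeMod(-7, 3)) // 2: shifted into [0, 3)

    // Arrays hash by identity, so equal contents usually give different hashCodes
    val a1 = Array(1, 2, 3)
    val a2 = Array(1, 2, 3)
    println(a1.hashCode == a2.hashCode)               // typically false
    println(a1.toList.hashCode == a2.toList.hashCode) // true: List hashes by content
  }
}
```

    Converting array keys to a List or Seq (which hash by content) before calling partitionBy avoids the identity-hash problem.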

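    Range partitioning, RangePartitioner: it samples the keys, sorts the sample, and picks range bounds so that each partition receives a roughly equal share of the keys; sampling is what lets it avoid skewed partitions. Each input partition is first sampled to a fixed size (reservoir sampling, over-sampling 3x), and input partitions holding far more items than average are re-sampled without replacement (see rangeBounds below). The target sample size is samplePointsPerPartitionHint (default 20) times the number of partitions, capped at 1M. Keys must have an Ordering.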
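    A usage sketch, assuming a local SparkSession and made-up integer keys 1 to 1000: build a RangePartitioner from the RDD itself (its constructor runs a sampling job to compute rangeBounds) and pass it to partitionBy. RangePartitionDemo and rangesPerPartition are hypothetical names.

```scala
import org.apache.spark.RangePartitioner
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

object RangePartitionDemo {
  // (smallest key, largest key, record count) per partition; None for an empty partition
  def rangesPerPartition(rdd: RDD[(Int, Int)]): Seq[Option[(Int, Int, Int)]] =
    rdd.glom().collect().toSeq.map { part =>
      if (part.isEmpty) None
      else Some((part.map(_._1).min, part.map(_._1).max, part.length))
    }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("RangePartitionDemo").getOrCreate()
    val pairs = spark.sparkContext.parallelize((1 to 1000).map(k => (k, 1)), 4)

    // Building the partitioner samples `pairs` right away to compute its range bounds
    val rp = new RangePartitioner(3, pairs)
    val ranged = pairs.partitionBy(rp)

    println(s"numPartitions = ${rp.numPartitions}")
    rangesPerPartition(ranged).zipWithIndex.foreach { case (r, i) => println(s"part_$i: $r") }
    spark.stop()
  }
}
```

    Each partition should then hold one contiguous, non-overlapping key range of roughly a third of the keys; the exact bounds depend on the sample.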

    class RangePartitioner[K : Ordering : ClassTag, V](
        partitions: Int,
        rdd: RDD[_ <: Product2[K, V]],
        private var ascending: Boolean = true,
        val samplePointsPerPartitionHint: Int = 20)
      extends Partitioner {
    
      // A constructor declared in order to maintain backward compatibility for Java, when we add the
      // 4th constructor parameter samplePointsPerPartitionHint. See SPARK-22160.
      // This is added to make sure from a bytecode point of view, there is still a 3-arg ctor.
      def this(partitions: Int, rdd: RDD[_ <: Product2[K, V]], ascending: Boolean) = {
        this(partitions, rdd, ascending, samplePointsPerPartitionHint = 20)
      }
    
      // We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
      require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")
      require(samplePointsPerPartitionHint > 0,
        s"Sample points per partition must be greater than 0 but found $samplePointsPerPartitionHint")
    
      private var ordering = implicitly[Ordering[K]]
    
      // An array of upper bounds for the first (partitions - 1) partitions
      private var rangeBounds: Array[K] = {
        if (partitions <= 1) {
          Array.empty
        } else {
          // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
          // Cast to double to avoid overflowing ints or longs
          val sampleSize = math.min(samplePointsPerPartitionHint.toDouble * partitions, 1e6)
          // Assume the input partitions are roughly balanced and over-sample a little bit.
          val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt
          val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
          if (numItems == 0L) {
            Array.empty
          } else {
            // If a partition contains much more than the average number of items, we re-sample from it
            // to ensure that enough items are collected from that partition.
            val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
            val candidates = ArrayBuffer.empty[(K, Float)]
            val imbalancedPartitions = mutable.Set.empty[Int]
            sketched.foreach { case (idx, n, sample) =>
              if (fraction * n > sampleSizePerPartition) {
                imbalancedPartitions += idx
              } else {
                // The weight is 1 over the sampling probability.
                val weight = (n.toDouble / sample.length).toFloat
                for (key <- sample) {
                  candidates += ((key, weight))
                }
              }
            }
            if (imbalancedPartitions.nonEmpty) {
              // Re-sample imbalanced partitions with the desired sampling probability.
              val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
              val seed = byteswap32(-rdd.id - 1)
              val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
              val weight = (1.0 / fraction).toFloat
              candidates ++= reSampled.map(x => (x, weight))
            }
            RangePartitioner.determineBounds(candidates, math.min(partitions, candidates.size))
          }
        }
      }
    
      def numPartitions: Int = rangeBounds.length + 1
    
      private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]
    
      def getPartition(key: Any): Int = {
        val k = key.asInstanceOf[K]
        var partition = 0
        if (rangeBounds.length <= 128) {
          // If we have less than 128 partitions naive search
          while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
            partition += 1
          }
        } else {
          // Determine which binary search method to use only once.
          partition = binarySearch(rangeBounds, k)
          // binarySearch either returns the match location or -[insertion point]-1
          if (partition < 0) {
            partition = -partition-1
          }
          if (partition > rangeBounds.length) {
            partition = rangeBounds.length
          }
        }
        if (ascending) {
          partition
        } else {
          rangeBounds.length - partition
        }
      }
    
      override def equals(other: Any): Boolean = other match {
        case r: RangePartitioner[_, _] =>
          r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending
        case _ =>
          false
      }
    
      override def hashCode(): Int = {
        val prime = 31
        var result = 1
        var i = 0
        while (i < rangeBounds.length) {
          result = prime * result + rangeBounds(i).hashCode
          i += 1
        }
        result = prime * result + ascending.hashCode
        result
      }
    
      @throws(classOf[IOException])
      private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
        val sfactory = SparkEnv.get.serializer
        sfactory match {
          case js: JavaSerializer => out.defaultWriteObject()
          case _ =>
            out.writeBoolean(ascending)
            out.writeObject(ordering)
            out.writeObject(binarySearch)
    
            val ser = sfactory.newInstance()
            Utils.serializeViaNestedStream(out, ser) { stream =>
              stream.writeObject(scala.reflect.classTag[Array[K]])
              stream.writeObject(rangeBounds)
            }
        }
      }
    
      @throws(classOf[IOException])
      private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
        val sfactory = SparkEnv.get.serializer
        sfactory match {
          case js: JavaSerializer => in.defaultReadObject()
          case _ =>
            ascending = in.readBoolean()
            ordering = in.readObject().asInstanceOf[Ordering[K]]
            binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int]
    
            val ser = sfactory.newInstance()
            Utils.deserializeViaNestedStream(in, ser) { ds =>
              implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
              rangeBounds = ds.readObject[Array[K]]()
            }
        }
      }
    }
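    The two core steps of RangePartitioner, choosing bounds from weighted samples and locating a key's partition, can be sketched without Spark. This is a simplified re-implementation for Int keys, not Spark's own code: determineBounds follows the idea of RangePartitioner.determineBounds (whose source is not listed above), and getPartition mirrors the linear-scan / binary-search logic in the listing. RangeBoundsSketch is a hypothetical name.

```scala
import scala.collection.mutable.ArrayBuffer

object RangeBoundsSketch {
  // candidates: (sampled key, weight), where weight ~ number of items that sample stands for.
  // Walk the sorted samples and emit a bound each time the cumulative weight reaches
  // the next multiple of step = totalWeight / partitions.
  def determineBounds(candidates: Seq[(Int, Float)], partitions: Int): Array[Int] = {
    val ordered = candidates.sortBy(_._1).toIndexedSeq
    val step = ordered.map(_._2.toDouble).sum / partitions
    val bounds = ArrayBuffer.empty[Int]
    var cumWeight = 0.0
    var target = step
    var i = 0
    while (i < ordered.length && bounds.length < partitions - 1) {
      val (key, weight) = ordered(i)
      cumWeight += weight
      // Skip duplicate keys so the bounds stay strictly increasing
      if (cumWeight >= target && (bounds.isEmpty || key > bounds.last)) {
        bounds += key
        target += step
      }
      i += 1
    }
    bounds.toArray
  }

  // Partition of `key`, given upper bounds for the first (n - 1) partitions
  def getPartition(rangeBounds: Array[Int], key: Int, ascending: Boolean = true): Int = {
    var partition = 0
    if (rangeBounds.length <= 128) {
      // Few bounds: linear scan, move right while the key is strictly greater than the bound
      while (partition < rangeBounds.length && key > rangeBounds(partition)) partition += 1
    } else {
      // Many bounds: binarySearch returns the index, or -(insertion point) - 1 if absent
      partition = java.util.Arrays.binarySearch(rangeBounds, key)
      if (partition < 0) partition = -partition - 1
    }
    if (ascending) partition else rangeBounds.length - partition
  }

  def main(args: Array[String]): Unit = {
    val samples = (1 to 10).map(k => (k, 1.0f))
    val bounds = determineBounds(samples, 5)
    println(bounds.mkString(", "))                            // 2, 4, 6, 8
    println(Seq(1, 2, 3, 9).map(k => getPartition(bounds, k))) // List(0, 0, 1, 4)
  }
}
```

    With ten samples of weight 1 and five partitions, the cumulative weight reaches 2, 4, 6 and 8, giving bounds 2, 4, 6, 8. A key equal to a bound belongs to the partition on its left, because getPartition only moves right while the key is strictly greater.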

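    Custom partitioner: partition according to your own business data to reduce data skew.
    To implement a custom partitioner, extend org.apache.spark.Partitioner and implement the following three methods:

    • numPartitions: Int: returns the number of partitions to create.
    • getPartition(key: Any): Int: returns the partition ID for a given key (0 to numPartitions-1).
    • equals(other: Any): Boolean: Spark uses this to tell whether two RDDs are partitioned the same way; when they are, operations such as join can skip a shuffle. Override hashCode consistently with it.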
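    import org.apache.spark.Partitioner

    // Custom partitioner: must extend Partitioner
    class UsridPartitioner(numParts: Int) extends Partitioner {
      require(numParts > 0, s"Number of partitions ($numParts) must be positive.")
      // Number of partitions
      override def numPartitions: Int = numParts
      // Partition ID for a key; must fall in [0, numPartitions)
      override def getPartition(key: Any): Int = {
        val raw = key match {
          case null    => 0
          case id: Int => id             // numeric user IDs are used directly
          case other   => other.hashCode // anything else falls back to its hashCode
        }
        val mod = raw % numParts
        if (mod < 0) mod + numParts else mod // keep the result non-negative
      }
      // Partitioners with the same partition count place keys identically
      override def equals(other: Any): Boolean = other match {
        case p: UsridPartitioner => p.numPartitions == numPartitions
        case _ => false
      }
      override def hashCode: Int = numPartitions
    }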
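    One business-driven way to ease skew with a custom partitioner, sketched under assumptions: if you already know that one key dominates (here the hypothetical hot key "a"), give it a partition of its own and hash every other key over the remaining partitions, so ordinary keys are not queued behind it. HotKeyPartitioner, HotKeyPartitionerDemo and the data are illustrative, not from the original article; a local SparkSession is assumed.

```scala
import org.apache.spark.Partitioner
import org.apache.spark.sql.SparkSession

// Hypothetical skew-aware partitioner: partition 0 is reserved for one known hot key,
// all other keys are hashed over partitions 1 .. numParts-1.
class HotKeyPartitioner(numParts: Int, val hotKey: Any) extends Partitioner {
  require(numParts >= 2, s"Need at least 2 partitions, got $numParts")

  override def numPartitions: Int = numParts

  override def getPartition(key: Any): Int =
    if (key == hotKey) 0
    else {
      val others = numParts - 1
      val h = if (key == null) 0 else key.hashCode % others
      1 + (if (h < 0) h + others else h) // shift into [1, numParts)
    }

  // Equal partitioners let Spark skip a shuffle when two RDDs already match
  override def equals(other: Any): Boolean = other match {
    case p: HotKeyPartitioner => p.numPartitions == numPartitions && p.hotKey == hotKey
    case _ => false
  }

  override def hashCode: Int = 31 * numParts + (if (hotKey == null) 0 else hotKey.hashCode)
}

object HotKeyPartitionerDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("HotKeyPartitionerDemo").getOrCreate()
    val data = spark.sparkContext.parallelize(
      Seq.fill(1000)(("a", 1)) ++ Seq(("b", 1), ("c", 1), ("d", 1)), 2)

    val partitioned = data.partitionBy(new HotKeyPartitioner(4, "a"))
    // Record count per partition: the hot key no longer shares a partition with other keys
    partitioned.glom().map(_.length).collect().zipWithIndex.foreach { case (n, i) =>
      println(s"part_$i: $n")
    }
    spark.stop()
  }
}
```

    This only isolates the hot key; all of its records still land in a single partition. To actually split it, combine this with the random-suffix approach mentioned at the start of the article.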
  • Original article: https://www.cnblogs.com/itboys/p/9853691.html