  • MinHash PySpark source code analysis: the hash join table is the key

    From the analysis below, we can see that the hash values are computed first, and a hash join table is then used to bring together the rows whose hash values are equal. A UDF then computes the distance for each candidate pair, and finally the pairs that satisfy the threshold are filtered out (a short PySpark usage sketch follows the excerpt):

    Reference: https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
      /**
       * Join two datasets to approximately find all pairs of rows whose distance are smaller than
       * the threshold. If the [[outputCol]] is missing, the method will transform the data; if the
       * [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the transformed
       * data when necessary.
       *
       * @param datasetA One of the datasets to join.
       * @param datasetB Another dataset to join.
       * @param threshold The threshold for the distance of row pairs.
       * @param distCol Output column for storing the distance between each pair of rows.
       * @return A joined dataset containing pairs of rows. The original rows are in columns
       *         "datasetA" and "datasetB", and a column "distCol" is added to show the distance
       *         between each pair.
       */
      def approxSimilarityJoin(
          datasetA: Dataset[_],
          datasetB: Dataset[_],
          threshold: Double,
          distCol: String): Dataset[_] = {
    
        val leftColName = "datasetA"
        val rightColName = "datasetB"
        val explodeCols = Seq("entry", "hashValue")
        val explodedA = processDataset(datasetA, leftColName, explodeCols)
    
        // If this is a self join, we need to recreate the inputCol of datasetB to avoid ambiguity.
        // TODO: Remove recreateCol logic once SPARK-17154 is resolved.
        val explodedB = if (datasetA != datasetB) {
          processDataset(datasetB, rightColName, explodeCols)
        } else {
          val recreatedB = recreateCol(datasetB, $(inputCol), s"${$(inputCol)}#${Random.nextString(5)}")
          processDataset(recreatedB, rightColName, explodeCols)
        }
    
        // Do a hash join on where the exploded hash values are equal.
        val joinedDataset = explodedA.join(explodedB, explodeCols)
          .drop(explodeCols: _*).distinct()
    
        // Add a new column to store the distance of the two rows.
        val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y), DataTypes.DoubleType)
        val joinedDatasetWithDist = joinedDataset.select(col("*"),
          distUDF(col(s"$leftColName.${$(inputCol)}"), col(s"$rightColName.${$(inputCol)}")).as(distCol)
        )
    
        // Filter the joined datasets where the distance are smaller than the threshold.
        joinedDatasetWithDist.filter(col(distCol) < threshold)
      }
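
    For reference, here is a minimal PySpark sketch of driving this code path from the Python API. The data, column names, and threshold are made up for illustration; `MinHashLSH` and `approxSimilarityJoin` are the actual API entry points.

        from pyspark.ml.feature import MinHashLSH
        from pyspark.ml.linalg import Vectors
        from pyspark.sql import SparkSession
        from pyspark.sql.functions import col

        spark = SparkSession.builder.getOrCreate()

        # Toy "sets" encoded as sparse binary vectors; the indices are the set elements.
        dfA = spark.createDataFrame(
            [(0, Vectors.sparse(6, [0, 1, 2], [1.0, 1.0, 1.0])),
             (1, Vectors.sparse(6, [2, 3, 4], [1.0, 1.0, 1.0]))],
            ["id", "features"])
        dfB = spark.createDataFrame(
            [(2, Vectors.sparse(6, [0, 2, 4], [1.0, 1.0, 1.0])),
             (3, Vectors.sparse(6, [1, 3, 5], [1.0, 1.0, 1.0]))],
            ["id", "features"])

        mh = MinHashLSH(inputCol="features", outputCol="hashes", numHashTables=5)
        model = mh.fit(dfA)

        # Internally this runs the approxSimilarityJoin shown above: hash both sides,
        # join on equal hash values, compute the Jaccard distance with a UDF,
        # then keep only the pairs whose distance is below the threshold.
        model.approxSimilarityJoin(dfA, dfB, 0.8, distCol="JaccardDistance") \
            .select(col("datasetA.id").alias("idA"),
                    col("datasetB.id").alias("idB"),
                    col("JaccardDistance")) \
            .show()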
    

    Addendum:

    Time complexity of SQL join algorithms

    References: stackoverflow, notes

    The SQL statement is as follows:

    SELECT  T1.name, T2.date
    FROM    T1, T2
    WHERE   T1.id=T2.id
            AND T1.color='red'
            AND T2.type='CAR'

    Assume T1 has m rows and T2 has n rows. In the naive case, we scan every row of T1 (m rows) and, for each one, scan T2 (n rows) looking for rows with T2.id = T1.id to join. The time complexity is O(m*n).
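
    A naive nested-loop version in Python, just to make the O(m*n) cost concrete (the table rows here are hypothetical):

        # Hypothetical rows: T1(id, name, color) with m rows, T2(id, date, type) with n rows.
        T1 = [(1, "a", "red"), (2, "b", "blue"), (3, "c", "red")]
        T2 = [(1, "2019-01-01", "CAR"), (3, "2019-02-01", "CAR"), (3, "2019-03-01", "BUS")]

        # Every row of T1 is compared with every row of T2 -> O(m*n) comparisons.
        result = [(name, date)
                  for (id1, name, color) in T1 if color == "red"
                  for (id2, date, typ) in T2 if typ == "CAR" and id1 == id2]
        print(result)  # [('a', '2019-01-01'), ('c', '2019-02-01')]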

    If there is no index, the engine will pick a hash join or a merge join as the optimization.

    A hash join works like this:

    1. Choose the table to be hashed, usually the smaller one. Let us happily assume T1 is the smaller one.
    2. All records of T1 are scanned. Every record that satisfies color='red' goes into the hash table, with id as the key and name as the value.
    3. All records of T2 are scanned. For every record that satisfies type='CAR', its id is used to probe the hash table; for each hit, the name value is returned together with the current record's date value, so the two sides can be joined.

    The time complexity is O(m+n): building the hash table over T1 is O(m), and probing it once per row of T2 is O(n); the two costs simply add.
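
    A minimal Python sketch of the same hash-join idea (the rows are hypothetical, as above):

        T1 = [(1, "a", "red"), (2, "b", "blue"), (3, "c", "red")]                             # (id, name, color), m rows
        T2 = [(1, "2019-01-01", "CAR"), (3, "2019-02-01", "CAR"), (3, "2019-03-01", "BUS")]   # (id, date, type), n rows

        # Steps 1-2: build a hash table over the smaller table T1, keyed by id -> O(m).
        hashed = {}
        for id1, name, color in T1:
            if color == "red":
                hashed.setdefault(id1, []).append(name)

        # Step 3: probe the hash table once per qualifying row of T2 -> O(n).
        result = [(name, date)
                  for (id2, date, typ) in T2 if typ == "CAR"
                  for name in hashed.get(id2, [])]
        print(result)  # [('a', '2019-01-01'), ('c', '2019-02-01')]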

    A merge join works like this:

    1. Copy T1(id, name) and sort it by id.
    2. Copy T2(id, date) and sort it by id.
    3. Point a pointer at the smallest value in each table.

        >1 2<
         2 3
         2 4
         3 5

    4. In a loop, compare the two pointers. If they match, return the record. If they do not match, advance whichever pointer points at the smaller value.

>1  2<  - no match, the left value is smaller, advance the left pointer
     2  3
     2  4
     3  5
    
 1  2<  - match, return the record, advance both pointers
    >2  3
     2  4
     3  5
    
 1  2  - match, return the record, advance both pointers
     2  3< 
     2  4
    >3  5
    
 1  2 - the left pointer has run off the end, the query is finished.
     2  3
     2  4<
     3  5
    >

    The time complexity is O(n*log(n) + m*log(m)): the two sorts cost O(n*log(n)) and O(m*log(m)) respectively, and we simply add them (the merge scan itself is only O(n+m)).
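
    A minimal Python sketch of the merge pass over the two already-sorted id columns (hypothetical data; duplicates are handled in the same simplified way as the walk-through above, so every match advances both pointers):

        T1 = [(1, "a"), (2, "b"), (2, "c"), (3, "d")]                      # (id, name), sorted by id
        T2 = [(2, "2019-01-01"), (3, "2019-02-01"), (4, "w"), (5, "x")]    # (id, date), sorted by id

        i = j = 0
        result = []
        while i < len(T1) and j < len(T2):
            if T1[i][0] == T2[j][0]:
                result.append((T1[i][1], T2[j][1]))   # match: return the record, advance both pointers
                i += 1
                j += 1
            elif T1[i][0] < T2[j][0]:
                i += 1                                # left value is smaller, advance the left pointer
            else:
                j += 1                                # right value is smaller, advance the right pointer
        print(result)  # [('b', '2019-01-01'), ('d', '2019-02-01')]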

    So in this case, making the query more complex can actually make it faster, because fewer rows have to go through the join-level test?

    Of course.

    If the original query had no WHERE clause, such as

    SELECT  T1.name, T2.date
    FROM    T1, T2

    it would be simpler, but it would return more rows and run for longer.

      

    Addendum on the hash function:

    As you can see, hashFunction operates on the indices of the sparse input vector (the set elements), while the distance computation uses the Jaccard similarity (a small Python sketch follows the excerpt).

    from:https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala

    /**
     * :: Experimental ::
     *
     * Model produced by [[MinHashLSH]], where multiple hash functions are stored. Each hash function
     * is picked from the following family of hash functions, where a_i and b_i are randomly chosen
     * integers less than prime:
     *    `h_i(x) = ((x \cdot a_i + b_i) \mod prime)`
     *
     * This hash family is approximately min-wise independent according to the reference.
     *
     * Reference:
     * Tom Bohman, Colin Cooper, and Alan Frieze. "Min-wise independent linear permutations."
     * Electronic Journal of Combinatorics 7 (2000): R26.
     *
     * @param randCoefficients Pairs of random coefficients. Each pair is used by one hash function.
     */
    @Experimental
    @Since("2.1.0")
    class MinHashLSHModel private[ml](
        override val uid: String,
        private[ml] val randCoefficients: Array[(Int, Int)])
      extends LSHModel[MinHashLSHModel] {
    
      /** @group setParam */
      @Since("2.4.0")
      override def setInputCol(value: String): this.type = super.set(inputCol, value)
    
      /** @group setParam */
      @Since("2.4.0")
      override def setOutputCol(value: String): this.type = super.set(outputCol, value)
    
      @Since("2.1.0")
      override protected[ml] def hashFunction(elems: Vector): Array[Vector] = {
        require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.")
        val elemsList = elems.toSparse.indices.toList
        val hashValues = randCoefficients.map { case (a, b) =>
          elemsList.map { elem: Int =>
            ((1L + elem) * a + b) % MinHashLSH.HASH_PRIME
          }.min.toDouble
        }
        // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
        hashValues.map(Vectors.dense(_))
      }
    
      @Since("2.1.0")
      override protected[ml] def keyDistance(x: Vector, y: Vector): Double = {
        val xSet = x.toSparse.indices.toSet
        val ySet = y.toSparse.indices.toSet
        val intersectionSize = xSet.intersect(ySet).size.toDouble
        val unionSize = xSet.size + ySet.size - intersectionSize
        assert(unionSize > 0, "The union of two input sets must have at least 1 elements")
        1 - intersectionSize / unionSize
      }
    
      @Since("2.1.0")
      override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
        // Since it's generated by hashing, it will be a pair of dense vectors.
        // TODO: This hashDistance function requires more discussion in SPARK-18454
        x.zip(y).map(vectorPair =>
          vectorPair._1.toArray.zip(vectorPair._2.toArray).count(pair => pair._1 != pair._2)
        ).min
      }
    
      @Since("2.1.0")
      override def copy(extra: ParamMap): MinHashLSHModel = {
        val copied = new MinHashLSHModel(uid, randCoefficients).setParent(parent)
        copyValues(copied, extra)
      }
    
      @Since("2.1.0")
      override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this)
    }
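
    For intuition, here is a small Python sketch of the same two pieces: the hash family, which takes the minimum of ((1 + elem) * a_i + b_i) mod prime over the non-zero indices, and the Jaccard distance used as keyDistance. The prime and coefficients below are illustrative rather than the constants Spark uses.

        import random

        PRIME = 2038074743  # illustrative large prime; Spark keeps its own constant (MinHashLSH.HASH_PRIME)

        def make_coefficients(num_hashes, seed=42):
            """One random (a, b) pair per hash function, both less than the prime."""
            rng = random.Random(seed)
            return [(rng.randint(1, PRIME - 1), rng.randint(0, PRIME - 1))
                    for _ in range(num_hashes)]

        def min_hash(indices, coefficients):
            """For each (a, b), take the minimum of ((1 + elem) * a + b) % PRIME over the indices."""
            return [min(((1 + elem) * a + b) % PRIME for elem in indices)
                    for (a, b) in coefficients]

        def jaccard_distance(x_indices, y_indices):
            """keyDistance: 1 - |intersection| / |union| of the two index sets."""
            x, y = set(x_indices), set(y_indices)
            return 1.0 - len(x & y) / len(x | y)

        coeffs = make_coefficients(5)
        print(min_hash([0, 2, 4], coeffs))             # one minhash value per hash function
        print(jaccard_distance([0, 2, 4], [0, 1, 2]))  # 2 common / 4 total -> 0.5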
    

      

  • Original article: https://www.cnblogs.com/bonelee/p/11151729.html