直接上代码:
package horizon.graphx.util

import java.security.InvalidParameterException

import horizon.graphx.util.CollectionUtil.CollectionHelper
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

/**
 * Created by yepei.ye on 2017/1/19.
 * Description: for each of the given source vertices, finds the vertices within N hops of it
 * ("N-degree relations") and outputs their ids grouped by path length from the source.
 */
object GraphNdegUtil {
  /** An edge stops carrying messages once either endpoint knows at least this many (vertexId, distance) entries. */
  val maxNDegVerticesCount = 10000
  /** Vertices whose total (in + out) degree exceeds this value are removed before the traversal. */
  val maxDegree = 1000

  /**
   * Computes the N-degree relations of the chosen vertices in a graph given as a plain edge list.
   *
   * @param edges         (srcId, dstId) pairs; messages are exchanged in both directions along each edge
   * @param choosedVertex ids of the source vertices to compute relations for (not modified or unpersisted)
   * @param degree        maximum path length, must be >= 1
   * @tparam ED unused; kept for source compatibility
   * @return for each chosen vertex that survives the high-degree filter, a map from path length
   *         (0 = the vertex itself) to the ids of the vertices at that length
   */
  def aggNdegreedVertices[ED: ClassTag](edges: RDD[(VertexId, VertexId)], choosedVertex: RDD[VertexId], degree: Int): VertexRDD[Map[Int, Set[VertexId]]] = {
    val simpleGraph = Graph.fromEdgeTuples(edges, 0, Option(PartitionStrategy.EdgePartition2D), StorageLevel.MEMORY_AND_DISK_SER, StorageLevel.MEMORY_AND_DISK_SER)
    val result = aggNdegreedVertices(simpleGraph, choosedVertex, degree)
    // The result is persisted and materialized by the call above, so the helper graph can be released.
    simpleGraph.unpersist(blocking = false)
    result
  }

  /**
   * Like `aggNdegreedVertices(graph, ...)`, but reports the related vertices' attributes instead of their ids.
   *
   * The returned VertexRDD is persisted (MEMORY_AND_DISK_SER) and already materialized; callers may
   * `unpersist` it when done.
   *
   * @param graph         input graph; its vertex attributes are joined onto the related vertices
   * @param choosedVertex ids of the source vertices
   * @param degree        maximum path length, must be >= 1
   * @param sendFilter    (senderAttr, receiverAttr) => whether a message may flow along an edge in that direction
   * @return for each chosen vertex, a map from path length to the set of related vertices' attributes
   */
  def aggNdegreedVerticesWithAttr[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], choosedVertex: RDD[VertexId], degree: Int, sendFilter: (VD, VD) => Boolean = (_: VD, _: VD) => true): VertexRDD[Map[Int, Set[VD]]] = {
    val ndegs: VertexRDD[Map[Int, Set[VertexId]]] = aggNdegreedVertices(graph, choosedVertex, degree, sendFilter)
    // One record per (source, related vertex, distance); the attribute is filled in by the join below.
    val flated: RDD[Ver[VD]] = ndegs.flatMap { case (source, byDistance) =>
      byDistance.toSeq.flatMap { case (distance, ids) =>
        ids.toSeq.map(id => Ver(source, id, distance, null.asInstanceOf[VD]))
      }
    }
    val matched: RDD[Ver[VD]] = flated
      .map(v => (v.id, v))
      .join(graph.vertices)
      .map { case (_, (v, attr)) => v.copy(attr = attr) }
    val grouped: RDD[(VertexId, Map[Int, Set[VD]])] = matched
      .map(v => (v.source, ArrayBuffer(v)))
      .reduceByKey(_ ++= _)
      .map { case (source, vs) => (source, vs.map(v => (v.degree, Set(v.attr))).reduceByKey(_ ++ _).toMap) }
    val result: VertexRDD[Map[Int, Set[VD]]] = VertexRDD(grouped).persist(StorageLevel.MEMORY_AND_DISK_SER)
    // Materialize before releasing `ndegs`; unpersisting first would silently force a full recomputation.
    result.count()
    ndegs.unpersist(blocking = false)
    result
  }

  /**
   * Computes, for every chosen vertex, the vertices reachable within `degree` hops and their shortest
   * path length. Edges are traversed in both directions. Vertices with total degree > `maxDegree` are
   * removed first, and an edge stops exchanging messages once either endpoint holds
   * `maxNDegVerticesCount` entries, so results can be incomplete around such vertices.
   *
   * The returned VertexRDD is persisted (MEMORY_AND_DISK_SER) and already materialized; callers may
   * `unpersist` it when done. `choosedVertex` is left untouched.
   *
   * @param graph         input graph
   * @param choosedVertex ids of the source vertices
   * @param degree        maximum path length, must be >= 1
   * @param sendFilter    (senderAttr, receiverAttr) => whether a message may flow along an edge in that direction
   * @throws InvalidParameterException if `degree < 1`
   * @return for each chosen vertex kept in the graph, a map from path length (0 = the vertex itself,
   *         up to `degree`) to vertex ids
   */
  def aggNdegreedVertices[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED],
                                                      choosedVertex: RDD[VertexId],
                                                      degree: Int,
                                                      sendFilter: (VD, VD) => Boolean = (_: VD, _: VD) => true
                                                     ): VertexRDD[Map[Int, Set[VertexId]]] = {
    if (degree < 1) {
      throw new InvalidParameterException("度参数错误:" + degree)
    }
    val initVertex = choosedVertex.map(e => (e, true)).persist(StorageLevel.MEMORY_AND_DISK_SER)
    var g: Graph[DegVertex[VD], Int] = graph.outerJoinVertices(graph.degrees)((_, old, deg) => (deg.getOrElse(0), old))
      // drop high-degree vertices
      .subgraph(vpred = (_, a) => a._1 <= maxDegree)
      // every vertex starts out knowing only itself (distance 0); only chosen vertices start active
      .outerJoinVertices(initVertex)((id, old, isChosen) => DegVertex(old._2, isChosen.getOrElse(false), ArrayBuffer((id, 0))))
      // edge attributes are not used by the traversal
      .mapEdges(_ => 0)
      .cache()

    // The set of active vertices grows by one hop per round, so a vertex d hops away first gets to
    // exchange messages in round d. Its id then needs d - 1 more rounds to travel back to the source.
    // Discovering every vertex within `degree` hops therefore takes 2 * degree - 1 rounds.
    val rounds = 2 * degree - 1
    var round = 0
    var activeMessages = 1L
    while (round < rounds && activeMessages > 0) {
      val prevG = g
      // send the messages of round `round + 1`
      val messages = prevG.aggregateMessages[ArrayBuffer[(VertexId, Int)]](
        sendMsg(_, sendFilter), (a, b) => reduceVertexIds(a ++ b)
      ).persist(StorageLevel.MEMORY_AND_DISK_SER)
      activeMessages = messages.count()
      g = prevG.outerJoinVertices(messages)((vid, old, msg) =>
        msg.fold(old.copy(init = false))(m => updateVertexByMsg(vid, old, m))
      ).cache()
      // Materialize the new graph before releasing the caches it was computed from. Otherwise every
      // round would recompute the whole lineage of all previous rounds.
      g.vertices.count()
      g.edges.count()
      messages.unpersist(blocking = false)
      prevG.unpersistVertices(blocking = false)
      prevG.edges.unpersist(blocking = false)
      round += 1
    }

    val result: VertexRDD[Map[Int, Set[VertexId]]] = VertexRDD(
      g.vertices.join(initVertex).mapValues { case (vertex, _) =>
        // Chosen vertices can also learn about other chosen vertices farther than `degree` hops
        // (information flows from both ends), so keep only what was asked for.
        sortResult(vertex).filter { case (distance, _) => distance <= degree }
      }
    ).persist(StorageLevel.MEMORY_AND_DISK_SER)
    // materialize before releasing the inputs below
    result.count()
    initVertex.unpersist(blocking = false)
    g.unpersist(blocking = false)
    result
  }

  /** One (source, related vertex, distance) record; `attr` holds the related vertex's attribute once joined. */
  private case class Ver[VD: ClassTag](source: VertexId, id: VertexId, degree: Int, attr: VD = null.asInstanceOf[VD])

  /**
   * Merges an incoming message into a vertex's state. Every received distance is increased by one hop,
   * and the result is merged with what the vertex already knew, keeping the minimum distance per id.
   * The vertex stays active only if the message was non-empty.
   */
  private def updateVertexByMsg[VD: ClassTag](vertexId: VertexId, oldAttr: DegVertex[VD], msg: ArrayBuffer[(VertexId, Int)]): DegVertex[VD] = {
    val addOne = msg.map(e => (e._1, e._2 + 1))
    val newMsg = reduceVertexIds(oldAttr.degVertices ++ addOne)
    oldAttr.copy(init = msg.nonEmpty, degVertices = newMsg)
  }

  /** Groups a vertex's known (id, distance) pairs into distance -> set of ids. */
  private def sortResult[VD: ClassTag](degs: DegVertex[VD]): Map[Int, Set[VertexId]] =
    degs.degVertices.map(e => (e._2, Set(e._1))).reduceByKey(_ ++ _).toMap

  /**
   * Per-vertex traversal state.
   *
   * @param attr        the original vertex attribute
   * @param init        whether the vertex is active: chosen vertices start active, and afterwards a vertex
   *                    is active iff it received a non-empty message in the previous round
   * @param degVertices (vertexId, shortest known distance) pairs, including the vertex itself at distance 0
   */
  case class DegVertex[VD: ClassTag](var attr: VD, init: Boolean = false, degVertices: ArrayBuffer[(VertexId, Int)])

  /** Same shape as [[DegVertex]]; not referenced in this file (kept for compatibility). */
  case class VertexDegInfo[VD: ClassTag](var attr: VD, init: Boolean = false, degVertices: ArrayBuffer[(VertexId, Int)])

  /**
   * Exchanges known-vertex lists across one edge, in both directions.
   *
   * Nothing is sent unless all of the following hold:
   *  - at least one endpoint is active;
   *  - both endpoints are below `maxNDegVerticesCount` entries;
   *  - the endpoints do not already hold the same state.
   *
   * Each direction is additionally gated by `sendFilter(senderAttr, receiverAttr)`.
   */
  private def sendMsg[VD: ClassTag](e: EdgeContext[DegVertex[VD], Int, ArrayBuffer[(VertexId, Int)]], sendFilter: (VD, VD) => Boolean): Unit = {
    try {
      val src = e.srcAttr
      val dst = e.dstAttr
      if (src.degVertices.size < maxNDegVerticesCount && (src.init || dst.init) && dst.degVertices.size < maxNDegVerticesCount && !isAttrSame(src, dst)) {
        if (sendFilter(src.attr, dst.attr)) {
          e.sendToDst(reduceVertexIds(src.degVertices))
        }
        // Sender first, receiver second, mirroring the src -> dst call above.
        if (sendFilter(dst.attr, src.attr)) {
          e.sendToSrc(reduceVertexIds(dst.degVertices))
        }
      }
    } catch {
      case ex: Exception =>
        println(s"==========error found: exception:${ex.getMessage}," +
          s"edgeTriplet:(srcId:${e.srcId},srcAttr:(${e.srcAttr.attr},${e.srcAttr.init},${e.srcAttr.degVertices.size}))," +
          s"dstId:${e.dstId},dstAttr:(${e.dstAttr.attr},${e.dstAttr.init},${e.dstAttr.degVertices.size}),attr:${e.attr}")
        ex.printStackTrace()
        throw ex
    }
  }

  /** De-duplicates (vertexId, distance) pairs, keeping the minimum distance per vertex id. */
  private def reduceVertexIds(ids: ArrayBuffer[(VertexId, Int)]): ArrayBuffer[(VertexId, Int)] = ArrayBuffer() ++= ids.reduceByKey(Math.min)

  /** True when both endpoints have the same active flag and know exactly the same non-empty set of vertex ids. */
  private def isAttrSame[VD: ClassTag](a: DegVertex[VD], b: DegVertex[VD]): Boolean = a.init == b.init && allKeysAreSame(a.degVertices, b.degVertices)

  /** True when both lists contain the same non-empty set of vertex ids (distances are ignored). */
  private def allKeysAreSame(a: ArrayBuffer[(VertexId, Int)], b: ArrayBuffer[(VertexId, Int)]): Boolean = {
    val aKeys = a.map(_._1).toSet
    val bKeys = b.map(_._1).toSet
    aKeys.nonEmpty && aKeys == bKeys
  }
}
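其中 sortResult 等方法对 Traversable[(K, V)] 类型的本地集合调用了 reduceByKey。它不是 Spark RDD 的 reduceByKey,而是自行封装的隐式扩展方法(CollectionUtil.CollectionHelper),使用时需要 import horizon.graphx.util.CollectionUtil.CollectionHelper。代码如下: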
/**
 * Created by yepei.ye on 2016/12/21.
 * Description: reduceByKey-style extension methods for local (non-RDD) Scala collections.
 */
object CollectionUtil {

  /**
   * Adds `reduceByKey`-related methods to any collection of type `Traversable[(K, V)]`.
   *
   * @param collection the key/value pairs to operate on
   * @param kt         class tag of the key type
   * @param vt         class tag of the value type
   * @tparam K key type
   * @tparam V value type
   */
  implicit class CollectionHelper[K, V](collection: Traversable[(K, V)])(implicit kt: ClassTag[K], vt: ClassTag[V]) {

    /**
     * Groups the pairs by key and combines each group's values with `f`, in encounter order.
     *
     * Keys are destructured directly rather than matched with a `_: K` type pattern. That pattern
     * rejected `null` keys with a MatchError and produced unchecked-erasure warnings.
     *
     * @param f associative combine function for two values of the same key
     * @return one (key, combined value) pair per distinct key
     */
    def reduceByKey(f: (V, V) => V): Traversable[(K, V)] =
      collection.groupBy(_._1).map { case (key, pairs) => (key, pairs.map(_._2).reduce(f)) }

    /**
     * Like [[reduceByKey]], but also returns the values that were "reduced away".
     *
     * At each pairwise merge, if `f`'s result equals the accumulated value, the incoming value is recorded
     * as reduced; otherwise the accumulated value is.
     *
     * @param f combine function for two values of the same key
     * @return (reduced pairs, one per key; every (key, discarded value) pair)
     */
    def reduceByKeyWithReduced(f: (V, V) => V)(implicit kt: ClassTag[K], vt: ClassTag[V]): (Traversable[(K, V)], Traversable[(K, V)]) = {
      val reduced: ArrayBuffer[(K, V)] = ArrayBuffer()
      // `Map.map` is strict, so `reduced` is fully populated once `newSeq` is built.
      val newSeq = collection.groupBy(_._1).map { case (key, pairs) =>
        val kept = pairs.map(_._2).reduce { (acc, next) =>
          val newValue: V = f(acc, next)
          val reducedValue: V = if (newValue == acc) next else acc
          reduced += ((key, reducedValue))
          newValue
        }
        (key, kept)
      }
      (newSeq, reduced.toTraversable)
    }
  }
}