zoukankan      html  css  js  c++  java
  • KD Tree算法

    参考:http://blog.csdn.net/v_july_v/article/details/8203674

    #!/usr/bin/env python
    # -*- coding:utf8 -*-
    
    __author__ = 'zky@msn.cn'
    
    import sys
    import numpy
    import heapq
    import Queue
    
    class KDNode(object):
        """A named point that can be linked into a k-d tree.

        Instance fields:
          name     -- identifier supplied by the caller
          feature  -- coordinate vector supplied by the caller
          ki       -- split dimension index; starts at -1
          is_leaf  -- leaf flag; starts as False
          kd_left  -- left child node, or None
          kd_right -- right child node, or None
        """

        # Traversal orders accepted by traverse().
        _ORDERS = ('in', 'pre', 'post')

        def __init__(self, name, feature):
            self.name = name
            self.ki = -1
            self.is_leaf = False
            self.feature = feature
            self.kd_left = None
            self.kd_right = None

        def traverse(self, seq, order='in'):
            """Append the nodes of this subtree to ``seq`` in the given order.

            seq   -- list that receives the visited nodes (mutated in place)
            order -- 'in' (left, self, right), 'pre' (self, left, right) or
                     'post' (left, right, self)

            Raises ValueError for any other ``order``.
            """
            # Raise instead of assert: assertions are removed under ``-O``,
            # which would turn a misspelled order into a silent no-op.
            if order not in self._ORDERS:
                raise ValueError("unknown traversal order: %r" % (order,))

            if order == 'pre':
                seq.append(self)
            if self.kd_left is not None:
                self.kd_left.traverse(seq, order)
            if order == 'in':
                seq.append(self)
            if self.kd_right is not None:
                self.kd_right.traverse(seq, order)
            if order == 'post':
                seq.append(self)
    
    class NodeDistance(object):
        """Pairs a tree node with its distance to a query point.

        Ordering is deliberately reversed: the instance with the LARGER
        distance compares as the smaller one.  heapq only provides a min-heap,
        so with this ordering ``heap[0]`` is the farthest candidate.

        Instances with the same distance compare equal and hash alike.
        """

        def __init__(self, kd_node, distance):
            self.kd_node = kd_node
            self.distance = distance

        # Rich comparisons are spelled out because Python 3 ignores __cmp__;
        # without them "<", ">" and heapq raise TypeError there.
        def __lt__(self, other):
            if not isinstance(other, NodeDistance):
                return NotImplemented
            return self.distance > other.distance

        def __le__(self, other):
            if not isinstance(other, NodeDistance):
                return NotImplemented
            return self.distance >= other.distance

        def __gt__(self, other):
            if not isinstance(other, NodeDistance):
                return NotImplemented
            return self.distance < other.distance

        def __ge__(self, other):
            if not isinstance(other, NodeDistance):
                return NotImplemented
            return self.distance <= other.distance

        def __eq__(self, other):
            if not isinstance(other, NodeDistance):
                return NotImplemented
            return self.distance == other.distance

        def __ne__(self, other):
            if not isinstance(other, NodeDistance):
                return NotImplemented
            return self.distance != other.distance

        def __hash__(self):
            # Kept consistent with __eq__, which compares distances only.
            return hash(self.distance)

        def __cmp__(self, other):
            # Python 2 cmp() protocol, reversed like the rich comparisons.
            if other.distance > self.distance:
                return 1
            if other.distance < self.distance:
                return -1
            return 0
    
    def euclidean_distance(node1, node2):
        """Return the Euclidean distance between two nodes' ``feature`` vectors.

        node1, node2 -- objects exposing a ``feature`` sequence of numbers

        Raises ValueError when the two vectors differ in length.
        """
        if len(node1.feature) != len(node2.feature):
            raise ValueError(
                "feature length mismatch: %d != %d"
                % (len(node1.feature), len(node2.feature)))
        # zip pairs the coordinates directly; no index loop and no shadowing
        # of the builtin ``sum``.
        squared = sum(numpy.square(a - b)
                      for a, b in zip(node1.feature, node2.feature))
        return numpy.sqrt(squared)
    
    class KDTree(object):
        """k-d tree over node objects, with k-nearest-neighbour search.

        The tree is built in place: every node handed to the constructor has
        its ``ki``, ``is_leaf``, ``kd_left`` and ``kd_right`` attributes
        overwritten.
        """

        def __init__(self, nodes, n):
            """Build the tree.

            nodes -- sequence of objects exposing ``name`` and ``feature``
            n     -- number of dimensions of every ``feature`` vector
            """
            self.root = self.build_kdtree(nodes, n)
            self.n = n

        def build_kdtree(self, nodes, n):
            """Recursively build a subtree from ``nodes``.

            The split dimension is the one with the largest variance among
            ``nodes`` (dimension 0 when no variance is positive).  The node at
            the middle of the ordering along that dimension becomes the
            subtree root.

            Returns the subtree root, or None when ``nodes`` is empty.
            """
            if len(nodes) == 0:
                return None

            max_var = 0
            index = 0
            for i in range(n):
                # A real list, not a lazy map object, so numpy.var sees numbers.
                var = numpy.var([node.feature[i] for node in nodes])
                if var > max_var:
                    max_var = var
                    index = i

            sorted_nodes = sorted(nodes, key=lambda node: node.feature[index])
            # Floor division keeps the index an int under true division.
            mid = len(sorted_nodes) // 2
            root = sorted_nodes[mid]
            left_nodes = sorted_nodes[:mid]
            right_nodes = sorted_nodes[mid + 1:]

            root.ki = index
            # Assigned on both paths so a node reused from an earlier build
            # cannot keep a stale leaf flag.
            root.is_leaf = not left_nodes and not right_nodes
            root.kd_left = self.build_kdtree(left_nodes, n)
            root.kd_right = self.build_kdtree(right_nodes, n)
            return root

        def traverse_kdtree(self, order='in'):
            """Print the list of node names in ``order`` ('in', 'pre' or 'post').

            An empty tree prints an empty list.
            """
            seq = []
            if self.root is not None:
                self.root.traverse(seq, order)
            print([node.name for node in seq])

        def kdtree_bbf_knn(self, target, k):
            """Return up to ``k`` nearest neighbours of ``target``.

            target -- object exposing ``name`` and ``feature``; tree nodes
                      whose ``name`` equals ``target.name`` are skipped
            k      -- maximum number of neighbours to return

            Returns a list of NodeDistance sorted by ascending distance,
            ``[]`` when ``k <= 0``, or None when ``target.feature`` does not
            have ``self.n`` entries.
            """
            if len(target.feature) != self.n:
                return None
            if k <= 0:
                return []

            # Max-heap of the best candidates so far, emulated on heapq's
            # min-heap with entries (-distance, insertion order, NodeDistance).
            # The insertion counter breaks ties so the NodeDistance objects
            # themselves are never compared.
            best = []
            pushed = 0
            # Subtrees set aside during a descent; the most recent one is
            # resumed first (a stack, not a priority queue).
            pending = [self.root]
            while pending:
                expl = pending.pop()
                while expl is not None:
                    ki = expl.ki
                    kv = expl.feature[ki]

                    if expl.name != target.name:  # ignore target node itself
                        distance = euclidean_distance(expl, target)
                        entry = (-distance, pushed, NodeDistance(expl, distance))
                        if len(best) < k:
                            heapq.heappush(best, entry)
                            pushed += 1
                        elif distance < -best[0][0]:
                            # Closer than the farthest kept candidate.
                            heapq.heapreplace(best, entry)
                            pushed += 1

                    # Descend on the target's side of the split; the other
                    # side is a candidate for later exploration.
                    if target.feature[ki] <= kv:
                        unexpl, expl = expl.kd_right, expl.kd_left
                    else:
                        unexpl, expl = expl.kd_left, expl.kd_right

                    # Keep the far side only while fewer than k candidates
                    # exist, or while the splitting plane is nearer than the
                    # farthest kept candidate.
                    if unexpl is not None and (
                            len(best) < k
                            or abs(kv - target.feature[ki]) < -best[0][0]):
                        pending.append(unexpl)

            # Popping yields the farthest first; reverse for ascending order.
            ordered = [heapq.heappop(best)[2] for _ in range(len(best))]
            ordered.reverse()
            return ordered
    
    if __name__ == '__main__':
        # Demo: six sample points plus the query point fx.
        f1 = [7, 2]
        f2 = [5, 4]
        f3 = [9, 6]
        f4 = [2, 3]
        f5 = [4, 7]
        f6 = [8, 1]
        fx = [2, 4.5]
        n1 = KDNode('f1', f1)
        n2 = KDNode('f2', f2)
        n3 = KDNode('f3', f3)
        n4 = KDNode('f4', f4)
        n5 = KDNode('f5', f5)
        n6 = KDNode('f6', f6)
        nx = KDNode('fx', fx)

        # Sanity-check NodeDistance's reversed ordering: the smaller distance
        # compares as the greater object.
        n1_distance = NodeDistance(n4, 1.5)
        n2_distance = NodeDistance(n5, 3.2)
        n3_distance = NodeDistance(n2, 3.04)
        assert n1_distance > n2_distance
        assert n1_distance > n3_distance
        assert n2_distance < n3_distance

        # The query node is stored in the tree as well; kdtree_bbf_knn skips
        # nodes whose name matches the target's.
        tree = KDTree([n1, n2, n3, n4, n5, n6, nx], 2)
        tree.traverse_kdtree('in')
        knn = tree.kdtree_bbf_knn(nx, 3)
        # print(...) with one list argument prints the same text whether print
        # is a statement or a function.
        print([(n.kd_node.name, n.distance) for n in knn])
    
  • 相关阅读:
    Android AHandle AMessage
    android java 与C 通过 JNI双向通信
    android 系统给应用的jar
    UE4 unreliable 同步问题
    UE4 difference between servertravel and openlevel(多人游戏的关卡切换)
    UE4 Run On owing Client解析(RPC测试)
    UE4 TSubclassOf VS Native Pointer
    UE4 内容示例网络同步Learn
    UE4 多人FPS VR游戏制作笔记
    UE4 分层材质 Layerd Materials
  • 原文地址:https://www.cnblogs.com/ZisZ/p/6086253.html
Copyright © 2011-2022 走看看