zoukankan      html  css  js  c++  java
  • kd树 C++实现

    参考：百度百科 kd-tree 词条

      1 /*
      2  * kdtree.h
      3  *
      4  *  Created on: Mar 3, 2017
      5  *      Author: wxquare
      6  */
      7 
      8 #ifndef KDTREE_H_
      9 #define KDTREE_H_
     10 
     11 #include <vector>
     12 #include <cmath>
     13 #include <algorithm>
     14 #include <iostream>
     15 #include <stack>
     16 
     /**
      * KdTree<T>: a k-d tree over points represented as std::vector<T>.
      *
      * At each internal node the coordinate with the largest variance is used
      * as the split attribute, and the upper median of that coordinate is the
      * split value. The first point (in input order) whose coordinate equals
      * the split value is stored in the node. Other points whose coordinate is
      * <= the split value go to the left subtree; the rest go to the right
      * subtree. A subtree built from a single point is a leaf with
      * splitAttribute == -1.
      *
      * Preconditions (not checked): all points and query vectors have the same
      * dimension (at least 1), and no coordinate is NaN.
      *
      * The tree owns its nodes. The destructor frees them, and copying a KdTree
      * makes a deep copy.
      */
     template<typename T>
     class KdTree {
         struct kdNode {
             std::vector<T> vec;  // the point stored in this node
             // Index of the coordinate this node splits on; -1 marks a leaf
             // node, which has no split attribute.
             int splitAttribute;
             kdNode* lChild;  // points whose split coordinate is <= vec[splitAttribute]
             kdNode* rChild;  // points whose split coordinate is >  vec[splitAttribute]
             kdNode* parent;  // nullptr for the root of a (sub)tree

             kdNode(std::vector<T> v = { }, int split = 0, kdNode* lch = nullptr,
                     kdNode* rch = nullptr, kdNode* par = nullptr) :
                     vec(v), splitAttribute(split), lChild(lch), rChild(rch), parent(par) {}
         };

     private:
         kdNode *root;  // owned by this tree; nullptr for an empty tree

         // Frees every node of the subtree rooted at `node` (nullptr is a no-op).
         static void destroy(kdNode* node) {
             if (node == nullptr) {
                 return;
             }
             destroy(node->lChild);
             destroy(node->rChild);
             delete node;
         }

         // Returns a deep copy of the subtree rooted at `node`.
         // The copy's root gets `par` as its parent pointer.
         static kdNode* clone(const kdNode* node, kdNode* par) {
             if (node == nullptr) {
                 return nullptr;
             }
             kdNode* copy = new kdNode(node->vec, node->splitAttribute,
                     nullptr, nullptr, par);
             try {
                 copy->lChild = clone(node->lChild, copy);
                 copy->rChild = clone(node->rChild, copy);
             } catch (...) {
                 destroy(copy);  // do not leak a partially built copy
                 throw;
             }
             return copy;
         }

         // Searches the subtree rooted at `node` for a point strictly closer to
         // `target` than `bestDist`, updating `best`/`bestDist` when found.
         // The child on target's side of the splitting hyperplane is searched
         // first. The other child is searched only if the ball of radius
         // bestDist around target reaches that hyperplane.
         static void nearest(const std::vector<T>& target, const kdNode* node,
                 const kdNode*& best, double& bestDist) {
             if (node == nullptr) {
                 return;
             }
             const double d = getDistance(target, node->vec);
             if (best == nullptr || d < bestDist) {
                 best = node;
                 bestDist = d;
             }
             const int a = node->splitAttribute;
             if (a < 0) {
                 return;  // leaf node: no children to visit
             }
             // Signed distance from target to the splitting hyperplane.
             const double diff = static_cast<double>(target[a])
                     - static_cast<double>(node->vec[a]);
             // createKdTree puts points with coordinate <= the node's on the left.
             const kdNode* nearSide = diff <= 0 ? node->lChild : node->rChild;
             const kdNode* farSide = diff <= 0 ? node->rChild : node->lChild;
             nearest(target, nearSide, best, bestDist);
             if (std::fabs(diff) <= bestDist) {
                 nearest(target, farSide, best, bestDist);
             }
         }

     public:
         KdTree() : root(nullptr) {}

         // Builds the tree from `data`, one point per element.
         KdTree(const std::vector<std::vector<T>>& data) : root(createKdTree(data)) {}

         // Deep copy.
         KdTree(const KdTree& other) : root(clone(other.root, nullptr)) {}

         KdTree(KdTree&& other) noexcept : root(other.root) {
             other.root = nullptr;
         }

         KdTree& operator=(const KdTree& other) {
             if (this != &other) {
                 // Copy first so *this is unchanged if cloning throws.
                 kdNode* copy = clone(other.root, nullptr);
                 destroy(root);
                 root = copy;
             }
             return *this;
         }

         KdTree& operator=(KdTree&& other) noexcept {
             if (this != &other) {
                 destroy(root);
                 root = other.root;
                 other.root = nullptr;
             }
             return *this;
         }

         ~KdTree() {
             destroy(root);
         }

         // Matrix transpose: element [j][i] of `data` becomes element [i][j].
         // Assumes every row has the length of data[0].
         // Returns an empty matrix for empty input.
         std::vector<std::vector<T>> transpose(const std::vector<std::vector<T>>& data) {
             if (data.empty()) {
                 return {};
             }
             const size_t m = data.size();
             const size_t n = data[0].size();
             std::vector<std::vector<T>> trans(n, std::vector<T>(m, T()));
             for (size_t i = 0; i < n; i++) {
                 for (size_t j = 0; j < m; j++) {
                     trans[i][j] = data[j][i];
                 }
             }
             return trans;
         }

         // Population variance of `vec` (divides by n); 0 for an empty vector.
         double getVariance(const std::vector<T>& vec) {
             if (vec.empty()) {
                 return 0.0;
             }
             double sum = 0;
             for (const T& x : vec) {
                 sum += x;
             }
             const double avg = sum / vec.size();
             double squares = 0;  // sum of squared deviations
             for (const T& x : vec) {
                 const double dev = x - avg;
                 squares += dev * dev;
             }
             return squares / vec.size();
         }

         // Returns the index of the row (attribute) of `data` with the largest
         // variance. The lowest index wins ties; returns 0 for empty input.
         int getSplitAttribute(const std::vector<std::vector<T>>& data) {
             if (data.empty()) {
                 return 0;
             }
             int splitAttribute = 0;
             double maxVar = getVariance(data[0]);
             for (size_t i = 1; i < data.size(); i++) {
                 const double var = getVariance(data[i]);
                 if (var > maxVar) {
                     splitAttribute = static_cast<int>(i);
                     maxVar = var;
                 }
             }
             return splitAttribute;
         }

         // Sorts `vec` in place and returns its element at index size/2
         // (the upper median). Returns T() for an empty vector.
         T getSplitValue(std::vector<T>& vec) {
             if (vec.empty()) {
                 return T();
             }
             std::sort(vec.begin(), vec.end());
             return vec[vec.size() / 2];
         }

         // Euclidean distance between v1 and v2.
         // v2 must have at least v1.size() elements.
         static double getDistance(const std::vector<T>& v1, const std::vector<T>& v2) {
             double sum = 0;
             for (size_t i = 0; i < v1.size(); i++) {
                 // Subtract in double so integral T cannot overflow.
                 const double d = static_cast<double>(v1[i]) - static_cast<double>(v2[i]);
                 sum += d * d;
             }
             return std::sqrt(sum);
         }

         // Recursively builds a subtree from `data` and returns its root
         // (nullptr for empty input). The nodes are heap-allocated. When this is
         // called directly rather than through the constructor, this KdTree does
         // not own the result.
         kdNode* createKdTree(const std::vector<std::vector<T>>& data) {
             if (data.empty()) {
                 return nullptr;
             }
             if (data.size() == 1) {
                 return new kdNode(data[0], -1);  // leaf node
             }

             // Split attribute = max-variance column; split value = its median.
             std::vector<std::vector<T>> data_T = transpose(data);
             const int splitAttribute = getSplitAttribute(data_T);
             // Keep the median as T. Converting it to int would truncate
             // fractional values, so that no point might match it below.
             const T splitValue = getSplitValue(data_T[splitAttribute]);

             // The first point whose split coordinate equals splitValue becomes
             // this node. Such a point exists because splitValue was taken from
             // that same column.
             const std::vector<T>* splitPoint = nullptr;
             std::vector<std::vector<T>> left;
             std::vector<std::vector<T>> right;
             for (const std::vector<T>& sample : data) {
                 if (splitPoint == nullptr && sample[splitAttribute] == splitValue) {
                     splitPoint = &sample;
                 } else if (sample[splitAttribute] <= splitValue) {
                     left.push_back(sample);
                 } else {
                     right.push_back(sample);
                 }
             }

             kdNode* splitNode = new kdNode(*splitPoint, splitAttribute);
             try {
                 splitNode->lChild = createKdTree(left);
                 splitNode->rChild = createKdTree(right);
             } catch (...) {
                 destroy(splitNode);  // frees any child already attached
                 throw;
             }
             if (splitNode->lChild != nullptr) {
                 splitNode->lChild->parent = splitNode;
             }
             if (splitNode->rChild != nullptr) {
                 splitNode->rChild->parent = splitNode;
             }
             return splitNode;
         }

         /*
          * Nearest-neighbour search in the subtree rooted at `start`.
          * The search descends toward `target`, recording the closest point
          * seen so far. On the way back it visits the opposite child of a node
          * only when the sphere centred at `target`, with the current best
          * distance as radius, intersects that node's splitting hyperplane.
          * Otherwise that side cannot contain a closer point and is skipped.
          * Returns an empty vector when `start` is nullptr.
          */
         std::vector<T> searchNearestNeighbor(const std::vector<T>& target, kdNode* start) {
             const kdNode* best = nullptr;
             double bestDist = 0.0;
             nearest(target, start, best, bestDist);
             return best != nullptr ? best->vec : std::vector<T>();
         }

         // Nearest neighbour of `target` in the whole tree (empty if the tree is empty).
         std::vector<T> searchNearestNeighbor(const std::vector<T>& target) {
             return searchNearestNeighbor(target, root);
         }

         // Prints the whole tree to `os` (std::cout by default).
         void print(std::ostream& os = std::cout) {
             print(root, os);
         }

         // Prints the subtree rooted at `node` as
         // "[left:<subtree>(x,y,...)right:<subtree>]", omitting absent children.
         // A null subtree prints "[]".
         void print(kdNode* node, std::ostream& os = std::cout) {
             if (node == nullptr) {
                 os << "[]";
                 return;
             }
             os << "[";
             if (node->lChild) {
                 os << "left:";
                 print(node->lChild, os);
             }
             os << "(";
             for (size_t i = 0; i < node->vec.size(); i++) {
                 if (i != 0) {
                     os << ",";
                 }
                 os << node->vec[i];
             }
             os << ")";
             if (node->rChild) {
                 os << "right:";
                 print(node->rChild, os);
             }
             os << "]";
         }
     };
    225 
    226 #endif /* KDTREE_H_ */
  • 相关阅读:
    CSS overflow 隐藏属性
    CSS visibility 隐藏属性
    多线程中的detach
    多线程中join的解释(转)
    lib 和 dll 的区别、生成以及使用详解:(包括变量,函数,类导出3种情形)(转)
    堆和栈的区别
    ZMQ相关
    不同类型的指针加减(就是向前或向后移动)[转]
    memset函数及其用法,C语言memset函数详解
    zmq中的router和dealer
  • 原文地址:https://www.cnblogs.com/wxquare/p/6497302.html
Copyright © 2011-2022 走看看