zoukankan      html  css  js  c++  java
  • c++实现kd树

      1 #ifndef _KD_TREE_H_
      2 #define _KD_TREE_H_
      3 
      4 #include <memory>
      5 #include <vector>
      6 #include <algorithm>
      7 #include <iostream>
      8 #include <functional>
      9 #include <iomanip>
     10 #include <stack>
     11 #include <array>
     12 #include <cfloat>
     13 #include <cmath>
     14 
     15 namespace zstd
     16 {
     // A point in 3-D space, used both as tree input and as query/result type.
     struct threeD_node
     {
         // Coordinates: value[0] = x, value[1] = y, value[2] = z.
         double value[3];

         // Constructs the origin (0, 0, 0).
         threeD_node() : value{0.0, 0.0, 0.0} {}

         // Constructs the point (x, y, z).
         threeD_node(double x, double y, double z) : value{x, y, z} {}
     };
     33     struct sort_for_threeD_node
     34     {
     35         int dimension;
     36         sort_for_threeD_node(int d) :dimension(d){}
     37         bool operator ()(const threeD_node& lhs, const threeD_node& rhs)
     38         {
     39             if (dimension == 0)
     40                 return lhs.value[0] < rhs.value[0];
     41             else if (dimension == 1)
     42                 return lhs.value[1] < rhs.value[1];
     43             else if (dimension == 2)
     44                 return lhs.value[2] < rhs.value[2];
     45             else
     46                 std::cerr << "error in sort_for_threeD_node"<< std::endl;
     47             return false;
     48         }
     49     };
     50 
     // One node of the kd-tree. It holds a copy of the point, the axis it splits on,
     // and raw pointers to its children.
     // The implicitly generated copy operations copy the child pointers shallowly,
     // exactly as the former hand-written ones did (Rule of Zero).
     struct kd_node
     {
         double value[3] = {0.0, 0.0, 0.0}; // value[0] = x, value[1] = y, value[2] = z
         int kv = -1;                        // splitting axis: 0 = x, 1 = y, 2 = z; -1 = not set
         bool is_leaf = false;               // true when the node was built without children
         kd_node* left = nullptr;            // subtree with coordinates <= value[kv] on axis kv
         kd_node* right = nullptr;           // subtree with coordinates >= value[kv] on axis kv
     };
     88     class kd_tree
     89     {
     90     private:
     91         std::shared_ptr<kd_node> root;
     92         std::vector<threeD_node>& vec_ref;
     93         const int k = 3;
     94         const int cspace = 4;
     95     private:
     96         int get_dimension(int n) const
     97         {
     98             return n % k;
     99         }
    100         void sort_by_dimension(std::vector<threeD_node>& v, int dimension, int l, int r);
    101         kd_node* build_tree(int left, int right, kd_node* sp_node, int dimension);
    102         void _print_tree(kd_node* sp, bool left, int space);
    103 
    104         double distance(const kd_node& lhs, const threeD_node& rhs);
    105     public:
    106         explicit kd_tree(std::vector<threeD_node>&);
    107         kd_tree(const kd_tree&) = delete;
    108         kd_tree operator = (const kd_tree&) = delete;
    109         ~kd_tree(){};
    110 
    111         void print_tree();
    112         std::vector<threeD_node> find_k_nearest(int k, const threeD_node& D);
    113     };
    114     void kd_tree::sort_by_dimension(std::vector<threeD_node>& v, int dimension, int l, int r)
    115     {
    116         sort_for_threeD_node s(dimension);
    117         std::sort(v.begin()+l, v.begin()+r, s);
    118     }
    119     kd_tree::kd_tree(std::vector<threeD_node>& v) :vec_ref(v)
    120     {
    121         if (vec_ref.empty())
    122             root = nullptr;
    123         else
    124         {
    125             root = std::make_shared<kd_node>();
    126             int dimension = 0;
    127             sort_by_dimension(vec_ref, dimension, 0, vec_ref.size());
    128             int mid = vec_ref.size() / 2;
    129             root->value[0] = vec_ref[mid].value[0];
    130             root->value[1] = vec_ref[mid].value[1];
    131             root->value[2] = vec_ref[mid].value[2];
    132             root->kv = dimension;
    133             if (vec_ref.size() == 1)//root is leaf
    134             {
    135                 root->left = nullptr;
    136                 root->right = nullptr;
    137                 root->is_leaf = true;
    138             }
    139             else
    140             {
    141                 root->is_leaf = false;
    142                 root->left = build_tree(0, mid - 1, root->left, get_dimension(dimension + 1));
    143                 root->right = build_tree(mid + 1, vec_ref.size() - 1, root->right, get_dimension(dimension + 1));
    144             }
    145         }
    146     }
    147     kd_node* kd_tree::build_tree(int left, int right, kd_node* sp_node, int dimension)
    148     {
    149         dimension = get_dimension(dimension);
    150         sort_by_dimension(vec_ref, dimension, left, right + 1);
    151 
    152         if(left == right)//leaf
    153         {
    154             sp_node = new kd_node();
    155             sp_node->value[0] = vec_ref[left].value[0];
    156             sp_node->value[1] = vec_ref[left].value[1];
    157             sp_node->value[2] = vec_ref[left].value[2];
    158             sp_node->kv = dimension;
    159             sp_node->is_leaf = true;
    160             sp_node->left = nullptr;
    161             sp_node->right = nullptr;
    162 
    163             return sp_node;
    164         }
    165         else if (left < right)
    166         {
    167             int mid = left + (right - left) / 2;
    168             sp_node = new kd_node();
    169             sp_node->value[0] = vec_ref[mid].value[0];
    170             sp_node->value[1] = vec_ref[mid].value[1];
    171             sp_node->value[2] = vec_ref[mid].value[2];
    172             sp_node->kv = dimension;
    173             sp_node->is_leaf = false;
    174             sp_node->left = nullptr;
    175             sp_node->right = nullptr;
    176 
    177             sp_node->left = build_tree(left, mid - 1, sp_node->left, get_dimension(dimension + 1));
    178             sp_node->right = build_tree(mid + 1, right, sp_node->right, get_dimension(dimension + 1));
    179 
    180             return sp_node;
    181         }
    182         return nullptr;
    183     }
    184     void kd_tree::_print_tree(kd_node* sp, bool left, int space)
    185     {
    186         if (sp != nullptr)
    187         {
    188             _print_tree(sp->right, false, space + cspace);
    189             std::cout << std::setw(space);
    190             std::cout << "(" <<
    191                 sp->value[0] << ", " <<
    192                 sp->value[1] << ", " <<
    193                 sp->value[2] << ")";
    194             if (left)
    195                 std::cout << "left";
    196             else
    197                 std::cout << "right";
    198             if (sp->is_leaf)
    199                 std::cout << "------leaf";
    200             std::cout << std::endl;
    201             _print_tree(sp->left, true, space + cspace);
    202         }
    203         else
    204             std::cout << std::endl;
    205     }
    206     void kd_tree::print_tree()
    207     {
    208         std::cout << "kd_tree : " << std::endl;
    209         if (root != nullptr)
    210         {
    211             int space = 0;
    212             _print_tree(root->right, false, space + cspace);
    213             std::cout << "(" << 
    214                 root->value[0] << ", " << 
    215                 root->value[1] << ", " << 
    216                 root->value[2] << ")root" << std::endl;
    217             _print_tree(root->left, true, space + cspace);
    218         }
    219     }
    220     double kd_tree::distance(const kd_node& lhs, const threeD_node& rhs)
    221     {
    222         double v0 = lhs.value[0] - rhs.value[0];
    223         double v1 = lhs.value[1] - rhs.value[1];
    224         double v2 = lhs.value[2] - rhs.value[2];
    225         return sqrt(v0 * v0 + v1 * v1 + v2 * v2);
    226     }
    227     std::vector<threeD_node> kd_tree::find_k_nearest(int ks, const threeD_node& D)
    228     {
    229         std::vector<threeD_node> res;
    230         const kd_node *ptr_kd_node;
    231         if (static_cast<std::size_t>(ks) > vec_ref.size())
    232             return res;
    233         std::stack<kd_node> s;
    234         struct pair
    235         {
    236             double distance;
    237             kd_node node;
    238             pair() :distance(DBL_MAX), node(){ }
    239             bool operator < (const pair& rhs)
    240             {
    241                 return distance < rhs.distance;
    242             }
    243         };
    244         std::unique_ptr<pair[]> ptr_pair(new pair[ks]);
    245         //pair *ptr_pair = new pair[ks]();
    246         if (!ptr_pair)
    247             exit(-1);
    248 
    249         if (!root)//the tree is empty
    250             return std::vector<threeD_node>();
    251         else
    252         {
    253             if (D.value[root->kv] < root->value[root->kv])
    254             {
    255                 s.push(*root);
    256                 ptr_kd_node = root->left;
    257             }
    258             else
    259             {
    260                 s.push(*root);
    261                 ptr_kd_node = root->right;
    262             }
    263             while (ptr_kd_node != nullptr)
    264             {
    265                 if (D.value[ptr_kd_node->kv] < ptr_kd_node->value[ptr_kd_node->kv])
    266                 {
    267                     s.push(*ptr_kd_node);
    268                     ptr_kd_node = ptr_kd_node->left;
    269                 }
    270                 else
    271                 {
    272                     s.push(*ptr_kd_node);
    273                     ptr_kd_node = ptr_kd_node->right;
    274                 }
    275             }
    276             
    277             while (!s.empty())
    278             {
    279                 kd_node popped_kd_node;//±£´æ×îеĴÓÕ»ÖÐpop³öµÄkd_node
    280                 popped_kd_node = s.top();
    281                 s.pop();
    282                 double dist = distance(popped_kd_node, D);
    283                 std::sort(&ptr_pair[0], &ptr_pair[ks]);
    284                 if (dist < ptr_pair[ks-1].distance)
    285                 {
    286                     ptr_pair[ks-1].distance = dist;
    287                     ptr_pair[ks-1].node = popped_kd_node;
    288                 }
    289 
    290                 if (abs(D.value[popped_kd_node.kv] - popped_kd_node.value[popped_kd_node.kv])
    291                         >= dist)//Ô²²»ºÍpopped_kd_nodeµÄÁíÒ»°ëÇøÓòÏཻ
    292                     continue;
    293                 else//Ô²ºÍpopped_kd_nodeµÄÁíÒ»°ëÇøÓòÏཻ
    294                 {
    295                     if (D.value[popped_kd_node.kv] < popped_kd_node.value[popped_kd_node.kv])//right
    296                     {
    297                         kd_node *ptr = popped_kd_node.right;
    298                         while (ptr != nullptr)
    299                         {    
    300                             s.push(*ptr);
    301                             if (D.value[ptr->kv] < ptr->value[ptr->kv])
    302                                 ptr = ptr->left;
    303                             else
    304                                 ptr = ptr->right;
    305                         }
    306                     }
    307                     else//left
    308                     {
    309                         kd_node *ptr = popped_kd_node.left;
    310                         while (ptr != nullptr)
    311                         {
    312                             s.push(*ptr);
    313                             if (D.value[ptr->kv] < ptr->value[ptr->kv])
    314                                 ptr = ptr->left;
    315                             else
    316                                 ptr = ptr->right;
    317                         }
    318                     }
    319                 }
    320             }//end of while
    321             for(int i = 0; i != ks; ++i)
    322                 res.push_back(threeD_node(ptr_pair[i].node.value[0], 
    323                             ptr_pair[i].node.value[1], ptr_pair[i].node.value[2]));
    324         }//end of else
    325         //delete ptr_pair;
    326         return res;
    327     }
    328 
    329 }//end of namespace zstd
    330 
    331 #endif
     1 #include <string>
     2 #include <iostream>
     3 #include <new>
     4 #include <fstream>
     5 #include <vector>
     6 #include <algorithm>
     7 #include <ctime>
     8 
     9 #include "trie_tree.h"
    10 #include "kd_tree.h"
    11 
    12 int main()
    13 {
    14     std::vector<zstd::threeD_node> v, res;
    15     v.push_back(zstd::threeD_node(2, 3, 1));//14
    16     v.push_back(zstd::threeD_node(5, 4, 7));//90
    17     v.push_back(zstd::threeD_node(9, 6, 9));//198
    18     v.push_back(zstd::threeD_node(4, 7, 2));//69
    19     v.push_back(zstd::threeD_node(8, 1, 5));//90
    20     v.push_back(zstd::threeD_node(7, 2, 0));//53
    21     v.push_back(zstd::threeD_node(8, 8, 8));//192
    22     v.push_back(zstd::threeD_node(1, 2, 3));//14
    23     v.push_back(zstd::threeD_node(5, 2, 1));//30
    24     v.push_back(zstd::threeD_node(12, 23, 0));//673
    25     v.push_back(zstd::threeD_node(10, 0, 2));//104
    26     std::cout << "size: " << v.size() << std::endl;
    27     zstd::kd_tree tree(v);
    28     tree.print_tree();
    29     res = tree.find_k_nearest(11, zstd::threeD_node(0, 0, 0));
    30     std::cout << "-------" << std::endl;
    31     std::cout << "离点(0,0,0)最近的点依次是:" << std::endl;
    32     for (auto i : res)
    33     {
    34         std::cout << "(" << i.value[0] << ", " << i.value[1] << ", " << i.value[2] << ")" << std::endl;
    35     }
    36     system("pause");
    37     return 0;
    38 }

  • 相关阅读:
    poj1014 Dividing (多重背包)
    HDOJ 1316 How Many Fibs?
    最大字串和
    WHY IE AGAIN?
    Codeforces Round #143 (Div. 2) (ABCD 思维场)
    自用组帧工具
    菜鸟学EJB(二)——在同一个SessionBean中使用@Remote和@Local
    shell 块注释
    检测到在集成的托管管道模式下不适用的 ASP.NET 设置的解决方法
    Windows Myeclipse 10 安装 Perl 插件
  • 原文地址:https://www.cnblogs.com/zxh1210603696/p/3491254.html
Copyright © 2011-2022 走看看