zoukankan      html  css  js  c++  java
  • k近邻法的C++实现:kd树

    1.k近邻算法的思想

    给定一个训练集,对于新的输入实例,在训练集中找到与该实例最近的k个实例,这k个实例中的多数属于某个类,就把该输入实例分为这个类。

    因为要找到最近的k个实例,所以计算输入实例与训练集中实例之间的距离是关键!

    k近邻算法最简单的方法是线性扫描,这时要计算输入实例与每一个训练实例的距离,当训练集很大时,非常耗时,这种方法不可行,为了提高k近邻的搜索效率,常常考虑使用特殊的存储结构存储训练数据,以减少计算距离的次数,具体方法很多,这里介绍实现经典的kd树方法。

    2.构造kd树

    kd树是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构,kd树是二叉树

    下面举例说明:

    给定一个二维空间的数据集: T = {(2,3), (5,4), (9,6), (4,7), (8,1), (7,2)},构造一个平衡kd树。

    • 根结点对应包含数据集T的矩形选择x(1) 轴,6个数据点的x(1) 坐标的中位数是7,以超平面x(1) = 7将空间分为左右两个子矩形(子结点)
    • 左矩形以x(2) = 4为中位数分为两个子矩形
    • 右矩形以x(2) = 6 分为两个子矩形
    • 如此递归,直到两个子区域没有实例存在时停止

    构造的kd树如下:

    3.利用kd树搜索最近邻

    输入:已构造的kd树;目标点x;

    输出:x的最近邻

    • 在kd树中找出包含目标点x的叶结点:从根结点出发,递归的向下访问kd树,若目标点x的当前维的坐标小于切分点的坐标,则移动到左子结点,否则移动到右子结点,直到子结点为叶结点为止。
    • 以此叶结点为“当前最近点”
    • 递归地向上回退,在每个结点进行以下操作:(a)如果该结点保存的实例点比当前最近点距离目标点更近,则以该实例点为“当前最近点”;

      (b)当前最近点一定存在于某结点一个子结点对应的区域,检查该子结点的父结点的另
      一子结点对应区域是否有更近的点(即检查另一子结点对应的区域是否与以目标点为球
      心、以目标点与“当前最近点”间的距离为半径的球体相交);如果相交,可能在另一
      个子结点对应的区域内存在距目标点更近的点,移动到另一个子结点,接着递归进行最
      近邻搜索;如果不相交,向上回退

    • 当回退到根结点时,搜索结束,最后的“当前最近点”即为x的最近邻点。

    4.C++实现

      1 #include <iostream>
      2 #include <vector>
      3 #include <algorithm>
      4 #include <string>
      5 #include <cmath>
      6 using namespace std;
      7 
      8 
      9 
     10 
     11 struct KdTree{
     12     vector<double> root;
     13     KdTree* parent;
     14     KdTree* leftChild;
     15     KdTree* rightChild;
     16     //默认构造函数
     17     KdTree(){parent = leftChild = rightChild = NULL;}
     18     //判断kd树是否为空
     19     bool isEmpty()
     20     {
     21         return root.empty();
     22     }
     23     //判断kd树是否只是一个叶子结点
     24     bool isLeaf()
     25     {
     26         return (!root.empty()) && 
     27             rightChild == NULL && leftChild == NULL;
     28     }
     29     //判断是否是树的根结点
     30     bool isRoot()
     31     {
     32         return (!isEmpty()) && parent == NULL;
     33     }
     34     //判断该子kd树的根结点是否是其父kd树的左结点
     35     bool isLeft()
     36     {
     37         return parent->leftChild->root == root;
     38     }
     39     //判断该子kd树的根结点是否是其父kd树的右结点
     40     bool isRight()
     41     {
     42         return parent->rightChild->root == root;
     43     }
     44 };
     45 
     46 int data[6][2] = {{2,3},{5,4},{9,6},{4,7},{8,1},{7,2}};
     47 
     48 template<typename T>
     49 vector<vector<T> > Transpose(vector<vector<T> > Matrix)
     50 {
     51     unsigned row = Matrix.size();
     52     unsigned col = Matrix[0].size();
     53     vector<vector<T> > Trans(col,vector<T>(row,0));
     54     for (unsigned i = 0; i < col; ++i)
     55     {
     56         for (unsigned j = 0; j < row; ++j)
     57         {
     58             Trans[i][j] = Matrix[j][i];
     59         }
     60     }
     61     return Trans;
     62 }
     63 
     64 template <typename T>
     65 T findMiddleValue(vector<T> vec)
     66 {
     67     sort(vec.begin(),vec.end());
     68     auto pos = vec.size() / 2;
     69     return vec[pos];
     70 }
     71 
     72 
     73 //构建kd树
     74 void buildKdTree(KdTree* tree, vector<vector<double> > data, unsigned depth)
     75 {
     76 
     77     //样本的数量
     78     unsigned samplesNum = data.size();
     79     //终止条件
     80     if (samplesNum == 0)
     81     {
     82         return;
     83     }
     84     if (samplesNum == 1)
     85     {
     86         tree->root = data[0];
     87         return;
     88     }
     89     //样本的维度
     90     unsigned k = data[0].size();
     91     vector<vector<double> > transData = Transpose(data);
     92     //选择切分属性
     93     unsigned splitAttribute = depth % k;
     94     vector<double> splitAttributeValues = transData[splitAttribute];
     95     //选择切分值
     96     double splitValue = findMiddleValue(splitAttributeValues);
     97     //cout << "splitValue" << splitValue  << endl;
     98 
     99     // 根据选定的切分属性和切分值,将数据集分为两个子集
    100     vector<vector<double> > subset1;
    101     vector<vector<double> > subset2;
    102     for (unsigned i = 0; i < samplesNum; ++i)
    103     {
    104         if (splitAttributeValues[i] == splitValue && tree->root.empty())
    105             tree->root = data[i];
    106         else
    107         {
    108             if (splitAttributeValues[i] < splitValue)
    109                 subset1.push_back(data[i]);
    110             else
    111                 subset2.push_back(data[i]);
    112         }
    113     }
    114 
    115     //子集递归调用buildKdTree函数
    116 
    117     tree->leftChild = new KdTree;
    118     tree->leftChild->parent = tree;
    119     tree->rightChild = new KdTree;
    120     tree->rightChild->parent = tree;
    121     buildKdTree(tree->leftChild, subset1, depth + 1);
    122     buildKdTree(tree->rightChild, subset2, depth + 1);
    123 }
    124 
    125 //逐层打印kd树
    126 void printKdTree(KdTree *tree, unsigned depth)
    127 {
    128     for (unsigned i = 0; i < depth; ++i)
    129         cout << "	";
    130             
    131     for (vector<double>::size_type j = 0; j < tree->root.size(); ++j)
    132         cout << tree->root[j] << ",";
    133     cout << endl;
    134     if (tree->leftChild == NULL && tree->rightChild == NULL )//叶子节点
    135         return;
    136     else //非叶子节点
    137     {
    138         if (tree->leftChild != NULL)
    139         {
    140             for (unsigned i = 0; i < depth + 1; ++i)
    141                 cout << "	";
    142             cout << " left:";
    143             printKdTree(tree->leftChild, depth + 1);
    144         }
    145             
    146         cout << endl;
    147         if (tree->rightChild != NULL)
    148         {
    149             for (unsigned i = 0; i < depth + 1; ++i)
    150                 cout << "	";
    151             cout << "right:";
    152             printKdTree(tree->rightChild, depth + 1);
    153         }
    154         cout << endl;
    155     }
    156 }
    157 
    158 
    159 //计算空间中两个点的距离
    160 double measureDistance(vector<double> point1, vector<double> point2, unsigned method)
    161 {
    162     if (point1.size() != point2.size())
    163     {
    164         cerr << "Dimensions don't match!!" ;
    165         exit(1);
    166     }
    167     switch (method)
    168     {
    169         case 0://欧氏距离
    170             {
    171                 double res = 0;
    172                 for (vector<double>::size_type i = 0; i < point1.size(); ++i)
    173                 {
    174                     res += pow((point1[i] - point2[i]), 2);
    175                 }
    176                 return sqrt(res);
    177             }
    178         case 1://曼哈顿距离
    179             {
    180                 double res = 0;
    181                 for (vector<double>::size_type i = 0; i < point1.size(); ++i)
    182                 {
    183                     res += abs(point1[i] - point2[i]);
    184                 }
    185                 return res;
    186             }
    187         default:
    188             {
    189                 cerr << "Invalid method!!" << endl;
    190                 return -1;
    191             }
    192     }
    193 }
    194 //在kd树tree中搜索目标点goal的最近邻
    195 //输入:目标点;已构造的kd树
    196 //输出:目标点的最近邻
    197 vector<double> searchNearestNeighbor(vector<double> goal, KdTree *tree)
    198 {
    199     /*第一步:在kd树中找出包含目标点的叶子结点:从根结点出发,
    200     递归的向下访问kd树,若目标点的当前维的坐标小于切分点的
    201     坐标,则移动到左子结点,否则移动到右子结点,直到子结点为
    202     叶结点为止,以此叶子结点为“当前最近点”
    203     */
    204     unsigned k = tree->root.size();//计算出数据的维数
    205     unsigned d = 0;//维度初始化为0,即从第1维开始
    206     KdTree* currentTree = tree;
    207     vector<double> currentNearest = currentTree->root;
    208     while(!currentTree->isLeaf())
    209     {
    210         unsigned index = d % k;//计算当前维
    211         if (currentTree->rightChild->isEmpty() || goal[index] < currentNearest[index])
    212         {
    213             currentTree = currentTree->leftChild;
    214         }
    215         else
    216         {
    217             currentTree = currentTree->rightChild;
    218         }
    219         ++d;
    220     }
    221     currentNearest = currentTree->root;
    222 
    223     /*第二步:递归地向上回退, 在每个结点进行如下操作:
    224     (a)如果该结点保存的实例比当前最近点距离目标点更近,则以该例点为“当前最近点”
    225     (b)当前最近点一定存在于某结点一个子结点对应的区域,检查该子结点的父结点的另
    226     一子结点对应区域是否有更近的点(即检查另一子结点对应的区域是否与以目标点为球
    227     心、以目标点与“当前最近点”间的距离为半径的球体相交);如果相交,可能在另一
    228     个子结点对应的区域内存在距目标点更近的点,移动到另一个子结点,接着递归进行最
    229     近邻搜索;如果不相交,向上回退*/
    230 
    231     //当前最近邻与目标点的距离
    232     double currentDistance = measureDistance(goal, currentNearest, 0);
    233 
    234     //如果当前子kd树的根结点是其父结点的左孩子,则搜索其父结点的右孩子结点所代表
    235     //的区域,反之亦反
    236     KdTree* searchDistrict;
    237     if (currentTree->isLeft())
    238     {
    239         if (currentTree->parent->rightChild == NULL)
    240             searchDistrict = currentTree;
    241         else
    242             searchDistrict = currentTree->parent->rightChild;
    243     }
    244     else
    245     {
    246         searchDistrict = currentTree->parent->leftChild;
    247     }
    248 
    249     //如果搜索区域对应的子kd树的根结点不是整个kd树的根结点,继续回退搜索
    250     while (searchDistrict->parent != NULL)
    251     {
    252         //搜索区域与目标点的最近距离
    253         double districtDistance = abs(goal[(d+1)%k] - searchDistrict->parent->root[(d+1)%k]);
    254 
    255         //如果“搜索区域与目标点的最近距离”比“当前最近邻与目标点的距离”短,表明搜索
    256         //区域内可能存在距离目标点更近的点
    257         if (districtDistance < currentDistance )//&& !searchDistrict->isEmpty()
    258         {
    259 
    260             double parentDistance = measureDistance(goal, searchDistrict->parent->root, 0);
    261 
    262             if (parentDistance < currentDistance)
    263             {
    264                 currentDistance = parentDistance;
    265                 currentTree = searchDistrict->parent;
    266                 currentNearest = currentTree->root;
    267             }
    268             if (!searchDistrict->isEmpty())
    269             {
    270                 double rootDistance = measureDistance(goal, searchDistrict->root, 0);
    271                 if (rootDistance < currentDistance)
    272                 {
    273                     currentDistance = rootDistance;
    274                     currentTree = searchDistrict;
    275                     currentNearest = currentTree->root;
    276                 }
    277             }
    278             if (searchDistrict->leftChild != NULL)
    279             {
    280                 double leftDistance = measureDistance(goal, searchDistrict->leftChild->root, 0);
    281                 if (leftDistance < currentDistance)
    282                 {
    283                     currentDistance = leftDistance;
    284                     currentTree = searchDistrict;
    285                     currentNearest = currentTree->root;
    286                 }
    287             }
    288             if (searchDistrict->rightChild != NULL)
    289             {
    290                 double rightDistance = measureDistance(goal, searchDistrict->rightChild->root, 0);
    291                 if (rightDistance < currentDistance)
    292                 {
    293                     currentDistance = rightDistance;
    294                     currentTree = searchDistrict;
    295                     currentNearest = currentTree->root;
    296                 }
    297             }
    298         }//end if
    299 
    300         if (searchDistrict->parent->parent != NULL)
    301         {
    302             searchDistrict = searchDistrict->parent->isLeft()? 
    303                             searchDistrict->parent->parent->rightChild:
    304                             searchDistrict->parent->parent->leftChild;
    305         }
    306         else
    307         {
    308             searchDistrict = searchDistrict->parent;
    309         }
    310         ++d;
    311     }//end while
    312     return currentNearest;
    313 }
    314 
    315 int main()
    316 {
    317     vector<vector<double> > train(6, vector<double>(2, 0));
    318     for (unsigned i = 0; i < 6; ++i)
    319         for (unsigned j = 0; j < 2; ++j)
    320             train[i][j] = data[i][j];
    321 
    322     KdTree* kdTree = new KdTree;
    323     buildKdTree(kdTree, train, 0);
    324 
    325     printKdTree(kdTree, 0);
    326 
    327     vector<double> goal;
    328     goal.push_back(3);
    329     goal.push_back(4.5);
    330     vector<double> nearestNeighbor = searchNearestNeighbor(goal, kdTree);
    331     vector<double>::iterator beg = nearestNeighbor.begin();
    332     cout << "The nearest neighbor is: ";
    333     while(beg != nearestNeighbor.end()) cout << *beg++ << ",";
    334     cout << endl;
    335     return 0;
    336 }

    5. 运行

    下面是用上面举例构造的kd树求点(3,4.5)的最近邻:

    参考文献:李航《统计学习方法》,维基百科

  • 相关阅读:
    c#驱动操作mongodb辅助类MongoDBHelper
    c#多线程lock无效
    利用Aspose.Words将html转成pdf和将html转成word
    前后端值映射的问题
    本机部署流程详解
    Git 安装配置手册
    js对象数组(JSON) 根据某个共同字段 分组
    jquery中的$.fn的用法
    JSON.parse()与JSON.stringify()的区别
    添加编辑 时 数据不可重复验证
  • 原文地址:https://www.cnblogs.com/90zeng/p/kdtree.html
Copyright © 2011-2022 走看看