zoukankan      html  css  js  c++  java
  • [Machine Learning]kNN代码实现(Kd tree)

    算法的具体描述见李航《统计学习方法》第三章(k 近邻法)中关于 kd 树的构造与最近邻搜索的内容。

      1 //
      2 //  main.cpp
      3 //  kNN
      4 //
      5 //  Created by feng on 15/10/24.
      6 //  Copyright © 2015年 ttcn. All rights reserved.
      7 //
      8 
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>
using namespace std;
     14 
     15 template<typename T>
     16 struct KdTree {
     17     // ctor
     18     KdTree():parent(nullptr), leftChild(nullptr), rightChild(nullptr) {}
     19     
     20     // KdTree是否为空
     21     bool isEmpty() { return root.empty(); }
     22     
     23     // KdTree是否为叶子节点
     24     bool isLeaf() { return !root.empty() && !leftChild && !rightChild;}
     25     
     26     // KdTree是否为根节点
     27     bool isRoot() { return !isEmpty() && !parent;}
     28     
     29     // 判断KdTree是否为根节点的左儿子
     30     bool isLeft() { return parent->leftChild->root == root; }
     31     
     32     // 判断KdTree是否为根节点的右儿子
     33     bool isRight() { return parent->rightChild->root == root; }
     34     
     35     // 存放根节点的数据
     36     vector<T> root;
     37     
     38     // 父节点
     39     KdTree<T> *parent;
     40     
     41     // 左儿子
     42     KdTree<T> *leftChild;
     43     
     44     // 右儿子
     45     KdTree<T> *rightChild;
     46 };
     47 
     48 
     49 /**
     50  *  矩阵转置
     51  *
     52  *  @param matrix 原矩阵
     53  *
     54  *  @return 原矩阵的转置矩阵
     55  */
     56 template<typename T>
     57 vector<vector<T>> transpose(const vector<vector<T>> &matrix) {
     58     size_t rows = matrix.size();
     59     size_t cols = matrix[0].size();
     60     vector<vector<T>> trans(cols, vector<T>(rows, 0));
     61     for (size_t i = 0; i < cols; ++i) {
     62         for (size_t j = 0; j < rows; ++j) {
     63             trans[i][j] = matrix[j][i];
     64         }
     65     }
     66     
     67     return trans;
     68 }
     69 
     70 /**
     71  *  找中位数
     72  *
     73  *  @param vec 数组
     74  *
     75  *  @return 数组中的中位数
     76  */
     77 template<typename T>
     78 T findMiddleValue(vector<T> vec) {
     79     sort(vec.begin(), vec.end());
     80     size_t pos = vec.size() / 2;
     81     return vec[pos];
     82 }
     83 
     84 /**
     85  *  递归构造KdTree
     86  *
     87  *  @param tree  KdTree根节点
     88  *  @param data  数据矩阵
     89  *  @param depth 当前节点深度
     90  *
     91  *  @return void
     92  */
     93 template<typename T>
     94 void buildKdTree(KdTree<T> *tree, vector<vector<T>> &data, size_t depth) {
     95     // 输入数据个数
     96     size_t samplesNum = data.size();
     97     
     98     if (samplesNum == 0) {
     99         return;
    100     }
    101     
    102     if (samplesNum == 1) {
    103         tree->root = data[0];
    104         return;
    105     }
    106     
    107     // 每一个输入数据的维度,属性个数
    108     size_t k = data[0].size();
    109     vector<vector<T>> transData = transpose(data);
    110     
    111     // 找到当前切分点
    112     size_t splitAttributeIndex = depth % k;
    113     vector<T> splitAttributes = transData[splitAttributeIndex];
    114     T splitValue = findMiddleValue(splitAttributes);
    115     
    116     vector<vector<T>> leftSubSet;
    117     vector<vector<T>> rightSubset;
    118     
    119     for (size_t i = 0; i < samplesNum; ++i) {
    120         if (splitAttributes[i] == splitValue && tree->isEmpty()) {
    121             tree->root = data[i];
    122         } else if (splitAttributes[i] < splitValue) {
    123             leftSubSet.push_back(data[i]);
    124         } else {
    125             rightSubset.push_back(data[i]);
    126         }
    127     }
    128     
    129     tree->leftChild = new KdTree<T>;
    130     tree->leftChild->parent = tree;
    131     tree->rightChild = new KdTree<T>;
    132     tree->rightChild->parent = tree;
    133     buildKdTree(tree->leftChild, leftSubSet, depth + 1);
    134     buildKdTree(tree->rightChild, rightSubset, depth + 1);
    135 }
    136 
    137 /**
    138  *  递归打印KdTree
    139  *
    140  *  @param tree  KdTree
    141  *  @param depth 当前深度
    142  *
    143  *  @return void
    144  */
    145 template<typename T>
    146 void printKdTree(const KdTree<T> *tree, size_t depth) {
    147     for (size_t i = 0; i < depth; ++i) {
    148         cout << "	";
    149     }
    150     
    151     for (size_t i = 0; i < tree->root.size(); ++i) {
    152         cout << tree->root[i] << " ";
    153     }
    154     cout << endl;
    155     
    156     if (tree->leftChild == nullptr && tree->rightChild == nullptr) {
    157         return;
    158     } else {
    159         if (tree->leftChild) {
    160             for (int i = 0; i < depth + 1; ++i) {
    161                 cout << "	";
    162             }
    163             cout << "left : ";
    164             printKdTree(tree->leftChild, depth + 1);
    165         }
    166         
    167         cout << endl;
    168         
    169         if (tree->rightChild) {
    170             for (size_t i = 0; i < depth + 1; ++i) {
    171                 cout << "	";
    172             }
    173             cout << "right : ";
    174             printKdTree(tree->rightChild, depth + 1);
    175         }
    176         cout << endl;
    177     }
    178 }
    179 
    180 /**
    181  *  节点之间的欧氏距离
    182  *
    183  *  @param p1 节点1
    184  *  @param p2 节点2
    185  *
    186  *  @return 节点之间的欧式距离
    187  */
    188 template<typename T>
    189 T calDistance(const vector<T> &p1, const vector<T> &p2) {
    190     T res = 0;
    191     for (size_t i = 0; i < p1.size(); ++i) {
    192         res += pow(p1[i] - p2[i], 2);
    193     }
    194     
    195     return res;
    196 }
    197 
    198 /**
    199  *  搜索目标节点的最近邻
    200  *
    201  *  @param tree KdTree
    202  *  @param goal 待分类的节点
    203  *
    204  *  @return 最近邻节点
    205  */
    206 template <typename T>
    207 vector<T> searchNearestNeighbor(KdTree<T> *tree, const vector<T> &goal ) {
    208     // 节点数属性个数
    209     size_t k = tree->root.size();
    210     // 划分的索引
    211     size_t d = 0;
    212     KdTree<T> *currentTree = tree;
    213     vector<T> currentNearest = currentTree->root;
    214     // 找到目标节点的最叶节点
    215     while (!currentTree->isLeaf()) {
    216         size_t index = d % k;
    217         if (currentTree->rightChild->isEmpty() || goal[index] < currentNearest[index]) {
    218             currentTree = currentTree->leftChild;
    219         } else {
    220             currentTree = currentTree->rightChild;
    221         }
    222         
    223         ++d;
    224     }
    225     currentNearest = currentTree->root;
    226     T currentDistance = calDistance(goal, currentTree->root);
    227     
    228     KdTree<T> *searchDistrict;
    229     if (currentTree->isLeft()) {
    230         if (!(currentTree->parent->rightChild)) {
    231             searchDistrict = currentTree;
    232         } else {
    233             searchDistrict = currentTree->parent->rightChild;
    234         }
    235     } else {
    236         searchDistrict = currentTree->parent->leftChild;
    237     }
    238     
    239     while (!(searchDistrict->parent)) {
    240         T districtDistance = abs(goal[(d + 1) % k] - searchDistrict->parent->root[(d + 1) % k]);
    241         
    242         if (districtDistance < currentDistance) {
    243             T parentDistance = calDistance(goal, searchDistrict->parent->root);
    244             
    245             if (parentDistance < currentDistance) {
    246                 currentDistance = parentDistance;
    247                 currentTree = searchDistrict->parent;
    248                 currentNearest = currentTree->root;
    249             }
    250             
    251             if (!searchDistrict->isEmpty()) {
    252                 T rootDistance = calDistance(goal, searchDistrict->root);
    253                 if (rootDistance < currentDistance) {
    254                     currentDistance = rootDistance;
    255                     currentTree = searchDistrict;
    256                     currentNearest = currentTree->root;
    257                 }
    258             }
    259             
    260             if (!(searchDistrict->leftChild)) {
    261                 T leftDistance = calDistance(goal, searchDistrict->leftChild->root);
    262                 if (leftDistance < currentDistance) {
    263                     currentDistance = leftDistance;
    264                     currentTree = searchDistrict;
    265                     currentNearest = currentTree->root;
    266                 }
    267             }
    268             
    269             if (!(searchDistrict->rightChild)) {
    270                 T rightDistance = calDistance(goal, searchDistrict->rightChild->root);
    271                 if (rightDistance < currentDistance) {
    272                     currentDistance = rightDistance;
    273                     currentTree = searchDistrict;
    274                     currentNearest = currentTree->root;
    275                 }
    276             }
    277             
    278         }
    279         
    280         if (!(searchDistrict->parent->parent)) {
    281             searchDistrict = searchDistrict->parent->isLeft()? searchDistrict->parent->parent->rightChild : searchDistrict->parent->parent->leftChild;
    282         } else {
    283             searchDistrict = searchDistrict->parent;
    284         }
    285         ++d;
    286     }
    287     
    288     return currentNearest;
    289 }
    290 
    291 int main(int argc, const char * argv[]) {
    292     vector<vector<double>> trainDataSet{{2,3},{5,4},{9,6},{4,7},{8,1},{7,2}};
    293     KdTree<double> *kdTree = new KdTree<double>;
    294     buildKdTree(kdTree, trainDataSet, 0);
    295     printKdTree(kdTree, 0);
    296     
    297     vector<double> goal{3, 4.5};
    298     vector<double> nearestNeighbor = searchNearestNeighbor(kdTree, goal);
    299     
    300     for (auto i : nearestNeighbor) {
    301         cout << i << " ";
    302     }
    303     cout << endl;
    304     
    305     return 0;
    306 }
  • 相关阅读:
    游戏系统开发笔记(八)——场景对象管理
    dynomite:高可用多数据中心同步
    VBS错误代码释义
    VBScript 内置函数
    在oracle中,select语句查询字段中非纯数字值
    ASP里面令人震撼地自定义Debug类(VBScript)
    调试 ASP 程序脚本
    多文档界面的实现(DotNetBar的superTabControl)
    C#利用tabControl控件实现多窗体嵌入及关闭
    使用MDI窗体实现多窗口效果
  • 原文地址:https://www.cnblogs.com/skycore/p/4908873.html
Copyright © 2011-2022 走看看