具体描述见《统计学习方法》第三章。
//
//  main.cpp
//  kNN
//
//  Created by feng on 15/10/24.
//  Copyright © 2015 ttcn. All rights reserved.
//
//  kd-tree construction and exact nearest-neighbour search over points stored
//  as std::vector<T>.
//

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

/**
 *  One node of a kd-tree. A node stores a single point (`root`) and links to
 *  its parent and its two subtrees.
 *
 *  Ownership: a node owns its children and deletes them in its destructor, so
 *  destroying (or deleting) the top node releases the whole tree. `parent` is
 *  a non-owning back pointer. Copying is disabled because a member-wise copy
 *  would make two nodes delete the same children.
 */
template <typename T>
struct KdTree {
    KdTree() = default;

    ~KdTree() {
        delete leftChild;
        delete rightChild;
    }

    KdTree(const KdTree &) = delete;
    KdTree &operator=(const KdTree &) = delete;

    // True when this node holds no point.
    bool isEmpty() const { return root.empty(); }

    // True when this node holds a point and has no children.
    bool isLeaf() const { return !root.empty() && !leftChild && !rightChild; }

    // True when this node holds a point and has no parent.
    bool isRoot() const { return !isEmpty() && !parent; }

    // True when this node is its parent's left child. Compared by node
    // identity (not by point value) so that siblings storing equal points are
    // told apart; a parentless node is neither left nor right.
    bool isLeft() const { return parent != nullptr && parent->leftChild == this; }

    // True when this node is its parent's right child (see isLeft()).
    bool isRight() const { return parent != nullptr && parent->rightChild == this; }

    // The point stored in this node.
    std::vector<T> root;

    // Parent node (non-owning); nullptr for the top of the tree.
    KdTree<T> *parent = nullptr;

    // Left subtree (owned); nullptr when absent.
    KdTree<T> *leftChild = nullptr;

    // Right subtree (owned); nullptr when absent.
    KdTree<T> *rightChild = nullptr;
};

/**
 *  Matrix transpose.
 *
 *  @param matrix  source matrix, one inner vector per row
 *
 *  @return the transposed matrix; an empty matrix yields an empty result
 *
 *  @throws std::invalid_argument if the rows do not all have the same length
 */
template <typename T>
std::vector<std::vector<T>> transpose(const std::vector<std::vector<T>> &matrix) {
    if (matrix.empty()) {
        return {};
    }

    const std::size_t rows = matrix.size();
    const std::size_t cols = matrix[0].size();
    std::vector<std::vector<T>> trans(cols, std::vector<T>(rows));
    for (std::size_t r = 0; r < rows; ++r) {
        if (matrix[r].size() != cols) {
            throw std::invalid_argument("transpose: rows have different lengths");
        }
        for (std::size_t c = 0; c < cols; ++c) {
            trans[c][r] = matrix[r][c];
        }
    }

    return trans;
}

/**
 *  Find the middle value of a sequence.
 *
 *  @param vec  values (taken by value because they are reordered internally)
 *
 *  @return the element that would sit at index size()/2 if the values were
 *          sorted ascending (the upper median for an even count)
 *
 *  @throws std::invalid_argument if vec is empty
 */
template <typename T>
T findMiddleValue(std::vector<T> vec) {
    if (vec.empty()) {
        throw std::invalid_argument("findMiddleValue: empty input");
    }

    using Offset = typename std::vector<T>::difference_type;
    const auto middle = vec.begin() + static_cast<Offset>(vec.size() / 2);
    // Only the element at the middle position is needed, so a full sort is
    // unnecessary.
    std::nth_element(vec.begin(), middle, vec.end());
    return *middle;
}

/**
 *  Recursively build a kd-tree.
 *
 *  The split axis at a node is depth % k, where k is the dimension of the
 *  first sample. The node keeps the first sample whose coordinate on that
 *  axis equals the middle value; samples with a strictly smaller coordinate
 *  go to the left subtree and all remaining ones to the right subtree.
 *  A child node is created only for a non-empty subset, so every node in the
 *  finished tree holds a point. Any previous content of `tree` is discarded.
 *
 *  @param tree   node to fill in (nothing happens when nullptr)
 *  @param data   samples, one point per row; not modified
 *  @param depth  depth of `tree` within the whole kd-tree (0 for the top)
 *
 *  @return void
 *
 *  @throws std::out_of_range if a row is too short to have the split axis
 */
template <typename T>
void buildKdTree(KdTree<T> *tree, std::vector<std::vector<T>> &data, std::size_t depth) {
    if (tree == nullptr) {
        return;
    }

    // Start from a clean node so that rebuilding neither leaks the old
    // subtrees nor mixes old and new points.
    tree->root.clear();
    delete tree->leftChild;
    tree->leftChild = nullptr;
    delete tree->rightChild;
    tree->rightChild = nullptr;

    const std::size_t samplesNum = data.size();
    if (samplesNum == 0) {
        return;
    }

    if (samplesNum == 1) {
        tree->root = data[0];
        return;
    }

    // Dimension of a sample (number of attributes).
    const std::size_t k = data[0].size();
    if (k == 0) {
        // Zero-dimensional samples cannot be split (and depth % 0 is undefined).
        return;
    }

    // Collect only the coordinates on the split axis; the rest of the matrix
    // is not needed at this level.
    const std::size_t splitAttributeIndex = depth % k;
    std::vector<T> splitAttributes;
    splitAttributes.reserve(samplesNum);
    for (const std::vector<T> &sample : data) {
        splitAttributes.push_back(sample.at(splitAttributeIndex));
    }
    const T splitValue = findMiddleValue(splitAttributes);

    std::vector<std::vector<T>> leftSubset;
    std::vector<std::vector<T>> rightSubset;
    for (std::size_t i = 0; i < samplesNum; ++i) {
        // splitValue is a copy of one of the coordinates, so the exact
        // comparison is reliable here.
        if (tree->isEmpty() && splitAttributes[i] == splitValue) {
            tree->root = data[i];
        } else if (splitAttributes[i] < splitValue) {
            leftSubset.push_back(data[i]);
        } else {
            rightSubset.push_back(data[i]);
        }
    }

    if (!leftSubset.empty()) {
        tree->leftChild = new KdTree<T>;
        tree->leftChild->parent = tree;
        buildKdTree(tree->leftChild, leftSubset, depth + 1);
    }
    if (!rightSubset.empty()) {
        tree->rightChild = new KdTree<T>;
        tree->rightChild->parent = tree;
        buildKdTree(tree->rightChild, rightSubset, depth + 1);
    }
}

/**
 *  Recursively print a kd-tree to std::cout.
 *
 *  @param tree   subtree to print (nothing is printed when nullptr)
 *  @param depth  current depth, used as the indentation width
 *
 *  @return void
 */
template <typename T>
void printKdTree(const KdTree<T> *tree, std::size_t depth) {
    if (tree == nullptr) {
        return;
    }

    const auto indent = [](std::size_t width) {
        for (std::size_t i = 0; i < width; ++i) {
            std::cout << " ";
        }
    };

    indent(depth);
    for (const T &value : tree->root) {
        std::cout << value << " ";
    }
    std::cout << '\n';

    if (tree->leftChild == nullptr && tree->rightChild == nullptr) {
        return;
    }

    if (tree->leftChild) {
        indent(depth + 1);
        std::cout << "left : ";
        printKdTree(tree->leftChild, depth + 1);
    }
    std::cout << '\n';

    if (tree->rightChild) {
        indent(depth + 1);
        std::cout << "right : ";
        printKdTree(tree->rightChild, depth + 1);
    }
    std::cout << '\n';
}

/**
 *  Squared Euclidean distance between two points.
 *
 *  The square root is intentionally not taken: the value is only used to
 *  compare distances, and squaring preserves their order.
 *
 *  @param p1  first point
 *  @param p2  second point
 *
 *  @return sum of squared coordinate differences; when the points differ in
 *          length only the common leading coordinates are used
 */
template <typename T>
T calDistance(const std::vector<T> &p1, const std::vector<T> &p2) {
    const std::size_t dims = std::min(p1.size(), p2.size());
    T res = 0;
    for (std::size_t i = 0; i < dims; ++i) {
        const T diff = p1[i] - p2[i];
        res += diff * diff;
    }

    return res;
}

namespace detail {

/**
 *  Depth-first nearest-neighbour search below `node`.
 *
 *  @param node      subtree to search; nullptr and point-less nodes are skipped
 *  @param goal      query point
 *  @param depth     depth of `node`, selecting the split axis (depth % k)
 *  @param best      in/out: closest node found so far (nullptr if none yet)
 *  @param bestDist  in/out: squared distance from goal to `best`
 *
 *  @throws std::invalid_argument if a visited point and goal differ in dimension
 */
template <typename T>
void searchSubtree(const KdTree<T> *node, const std::vector<T> &goal, std::size_t depth,
                   const KdTree<T> *&best, T &bestDist) {
    if (node == nullptr || node->isEmpty()) {
        return;
    }
    if (node->root.size() != goal.size()) {
        throw std::invalid_argument("searchNearestNeighbor: dimension mismatch");
    }

    const T nodeDist = calDistance(goal, node->root);
    if (best == nullptr || nodeDist < bestDist) {
        best = node;
        bestDist = nodeDist;
    }

    // Descend first into the side of the splitting plane that contains goal.
    const std::size_t axis = depth % goal.size();
    const bool goalIsLeft = goal[axis] < node->root[axis];
    searchSubtree(goalIsLeft ? node->leftChild : node->rightChild, goal, depth + 1, best, bestDist);

    // The other side can hold a closer point only if the splitting plane is
    // nearer than the best match so far; both values are squared distances.
    const T gap = goal[axis] - node->root[axis];
    if (gap * gap < bestDist) {
        searchSubtree(goalIsLeft ? node->rightChild : node->leftChild, goal, depth + 1, best,
                      bestDist);
    }
}

}  // namespace detail

/**
 *  Find the stored point nearest to a query point (exact search).
 *
 *  The split axis of each node is taken to be its depth below `tree` modulo
 *  the dimension, i.e. the tree must have been produced by
 *  buildKdTree(tree, data, 0).
 *
 *  @param tree  top node of the kd-tree
 *  @param goal  query point
 *
 *  @return a copy of the nearest stored point (with several equally near
 *          points, one of them); an empty vector when tree is nullptr or
 *          holds no point
 *
 *  @throws std::invalid_argument if goal and a stored point differ in dimension
 */
template <typename T>
std::vector<T> searchNearestNeighbor(KdTree<T> *tree, const std::vector<T> &goal) {
    if (tree == nullptr || tree->isEmpty()) {
        return {};
    }

    const KdTree<T> *best = nullptr;
    T bestDist = T();
    detail::searchSubtree<T>(tree, goal, 0, best, bestDist);
    return best->root;
}

// Defining KNN_NO_MAIN drops the demo driver so that the definitions above
// can be compiled into another program.
#ifndef KNN_NO_MAIN
int main(int /*argc*/, const char * /*argv*/[]) {
    std::vector<std::vector<double>> trainDataSet{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}};
    KdTree<double> kdTree;
    buildKdTree(&kdTree, trainDataSet, 0);
    printKdTree(&kdTree, 0);

    const std::vector<double> goal{3, 4.5};
    const std::vector<double> nearestNeighbor = searchNearestNeighbor(&kdTree, goal);

    for (const double value : nearestNeighbor) {
        std::cout << value << " ";
    }
    std::cout << '\n';

    return 0;
}
#endif  // KNN_NO_MAIN