zoukankan      html  css  js  c++  java
  • kdTree实践

    一、kdTree 数据结构节点

    • left:  左子树
    • right:右子树
    • fea:所选轴(特征)
    • dataNode:所选轴中点的样本

    二、kdTree实现主要包括两部分:

    • 1、建树  :计算轴方差,选出方差最大的轴,进行递归二分
    • 2、查询:根据当前kdTree节点轴的值与要查询节点轴的值比较,选择向左子树(或右子树)递归查询,得到两点间左子树(或右子树)的最小距离dis;根据当前kdTree节点轴的值与要查询节点轴的差值作比较,若差值较大,则说明(超球面是否与超矩形交割)要对右子树(或左子树)回溯

    三、代码实现

     1 # -*- coding: utf-8 -*-
     2 """
     3 Created on Sun Sep 30 12:44:51 2018
     4 
     5 @author: Administrator
     6 """
     7 import pandas as pd
     8 import numpy as np
     9 import math
    10 #定义treeNode
    11 class Node:
    12     def __init__(self,lTree,rTree,fea,dataNode):  #fea表示选择的轴,dataNode 以该节点进行分割左右子树
    13         self.left=lTree;
    14         self.right=rTree;
    15         self.fea=fea;
    16         self.dataNode=dataNode                #标签包含在其中、
    17 
    18 
    19 ##直接用 DataFrame 作为数据结构
    20 def getInfo():
    21     data=[[2,3,''],[5,4,''],[9,6,''],[4,7,''],[8,1,''],[7,2,'']];              
    22     data=pd.DataFrame(data,columns=['fea1','fea2','label'])
    23     return data;
    24 
    25 # 计算方差,选择轴 根据轴方差
    26 def calSq(data):
    27     sq=data.var();        
    28     pos=data.columns[0];
    29     val=sq[0];
    30     for i in data.columns[1:-1]:   #选择方差最大的
    31         if(val<sq[i]):
    32             val=sq[i];
    33             pos=i;
    34     return pos;
    35 
    36  #按轴将数据拆分
    37 def splitAxis(data):  
    38     fea=calSq(data);
    39     sortData=data.sort_values(by=fea);   #按轴排序
    40     sortData=(np.array(sortData)).tolist();  #转list
    41     dataNode=pd.DataFrame( [ sortData[len(sortData)//2] ],    columns=list(data.columns));        #数据节点
    42     leftSet=pd.DataFrame( sortData[0:len(sortData)//2] , columns=list(data.columns) );    #左子树
    43     rightSet=pd.DataFrame(sortData[len(sortData)//2+1:] , columns=list(data.columns) );                  #右子树
    44     return fea,dataNode,leftSet,rightSet;
    45 
    46 #建树
    47 def createTree(data):   #递归建树
    48     if(len(data)>0):          #如果有数据
    49         fea,dataNode,leftSet,rightSet=splitAxis(data)
    50         treeNode=Node(None,None,fea,dataNode);
    51         if(len(leftSet)>0):            #左边是否可分
    52             treeNode.left=createTree(leftSet);
    53         if(len(rightSet)>0):         #右边是否可分
    54             treeNode.right=createTree(rightSet);
    55         return treeNode;
    56   
    57 #递归搜索      
    58 def search(tree,preNode):    #perNode 表示要查询一个样本;
    59     dis=0;
    60     for i in tree.dataNode.columns[:-1]:         #计算距离
    61         dis=dis+( tree.dataNode[i][0]-preNode[i][0] )**2;
    62     dis=math.sqrt(dis);
    63     label=tree.dataNode[tree.dataNode.columns[-1]][0];  #当前节点标记
    64     labelL='';
    65     labelR='';
    66     if(tree.left!=None and preNode[tree.fea][0] < tree.dataNode[tree.fea][0] ): #左边搜索
    67         disL,labelL = search( tree.left, preNode );
    68         if(disL<dis):                                                           #取距离最小的
    69             dis=disL
    70             label=labelL;
    71         if( dis >  abs(preNode[tree.fea][0] - tree.dataNode[tree.fea][0])): #超球面是否与超矩形交割 判断是否要回溯
    72             disHR,labelHR=search(tree.right,preNode);                                     #回溯右子树
    73             if(disHR<dis):
    74                 return disHR,labelHR
    75             else:
    76                 return dis,label
    77         
    78     if(tree.right!=None and preNode[tree.fea][0] >= tree.dataNode[tree.fea][0] ): #右边搜索
    79         disR,labelR=search(tree.right,preNode);
    80         if(disR < dis):                                                          #取距离最小的
    81             dis=disR;
    82             label=labelR;
    83         if( dis >  abs(preNode[tree.fea][0] - tree.dataNode[tree.fea][0])):  #超球面是否与超矩形交割 判断是否要回溯
    84             disHL,labelHL=search(tree.left,preNode);                        #回溯左子树
    85             if(disHL<dis):
    86                 return disHL,labelHL
    87             else:
    88                 return dis,label
    89     return dis,label;
    90     
    91 data=getInfo();
    92 root=createTree(data);
    93 test=pd.DataFrame( [ [7.1,1] ], columns=list(data.columns[:-1]));
    94 dis,label=search(root,test)
  • 相关阅读:
    缓存
    Java缓存
    数据库事务
    spring 事务管理
    MySQL错误解决10038
    mysql存储过程
    ECS修改默认端口22及限制root登录
    xunsearch安装配置
    https和http共存的nginx配置
    ECS 安装redis 及安装PHPredis的扩展
  • 原文地址:https://www.cnblogs.com/z-bear/p/9743382.html
Copyright © 2011-2022 走看看