zoukankan      html  css  js  c++  java
  • DBSCAN算法的Java,C++,Python实现

    最近由于要实现‘基于网格的DBSCAN算法’,网上有没有找到现成的代码[如果您有代码,麻烦联系我],只好参考已有的DBSCAN算法的实现。先从网上随便找了几篇放这儿,之后对比研究。

    DBSCAN简介:

    1.简介
      DBSCAN 算法是一种基于密度的空间聚类算法。该算法利用基于密度的聚类的概念,即要求聚类空间中的一定区域内所包含对象(点或其它空间对象)的数目不小于某一给定阀   值。DBSCAN 算法的显著优点是聚类速度快且能够有效处理噪声点和发现任意形状的空间聚类。但是由于它直接对整个数据库进行操作且进行聚类时使用了一个全局性的表征  密度的参数,因此也具有两个比较明显的弱点:

            1. 当数据量增大时,要求较大的内存支持 I/0 消耗也很大;

            2. 当空间聚类的密度不均匀、聚类间距离相差很大时,聚类质量较差。


    2.DBSCAN算法的聚类过程
      DBSCAN算法基于一个事实:一个聚类可以由其中的任何核心对象唯一确定。等价可以表述为: 任一满足核心对象条件的数据对象p,数据库D中所有从p密度可达的数据对象o  所组成的集合构成了一个完整的聚类C,且p属于C。


    3.DBSCAN中的几个定义
      密度可达是直接密度可达的传递闭包,非对称性关系;密度相连是对称性关系。DBSCA目的是找到密度相连对象的最大集合。

      E领域:给定对象p半径为E内的区域称为该对象的E领域;

      核心对象:p的E领域内样本数大于MinPts(算法输入值),则该对象p为核心对象;

      直接密度可达:对于样本集合D,如果样本点q在p的E领域内,且p为核心对象,则p直接密度可达q;

      密度可达:对于样本集合D,存在一串样本点p1,p2,p3,...pn,其中连续两个点直接密度可达,则 p=p1,q=qn,则p密度可达q;

      密度相连:对于样本集合D中任意一点o,存在p到o密度可达,并且q到o密度可达,那么q从p密度相连;


    算法伪代码

     1 DBSCAN(SetOfPoints, Eps, MinPts){
     2     ClusterId=nextId(NOISE)
     3     for(i=0;i<SetOfPoints.size();i++){
     4         point=SetOfPoints.get(i)
     5         if (point.ClId==UNCLASSIFIED){
     6             If(ExpandCluster(SetOfPoints,point, ClusterId, Eps, MinPts)){
     7                 ClusterId=nextId(ClusterId)
     8             }
     9         }
    10     }
    11 }
    12     
    13 ExpandCluster(SetOfPoints,Point, ClId, Eps, MinPts){
    14     seeds=SetOfPoints.regionQuery(Point, Eps)
    15     if(seeds.size()<MinPts){
    16         SetOfPoints.changeClId(Point,NOISE)
    17         return False
    18     }else{
    19         SetOfPoints.changeClIds(seeds,ClId)
    20         seeds.delete(Point)
    21         while(seeds.size()>0){
    22             currentP=seeds.first()
    23             result=SetOfPoints.regionQuery(currentP, Eps)
    24             if(result.size()>= MinPts){
    25                 for(i=0;i<result.size();i++){
    26                     resultP=result.get(i)
    27                     if(resultP.ClId ==UNCLASSIFIED or resultP.ClId ==NOISE){
    28                         if(resultP.ClId ==UNCLASSIFIED){
    29                             seeds.append(resultP)
    30                         }
    31                         SetOfPoints.changeClId(resultP,ClId)
    32                     }
    33                 }
    34             }
    35             seeds.delete(currentP)
    36         }
    37         return True
    38     }
    39 }
    View Code

    JAVA实现:

     1 package orisun;
     2  
     3 import java.io.File;
     4 import java.util.ArrayList;
     5 import java.util.Vector;
     6 import java.util.Iterator;
     7  
     8 public class DBScan {
     9  
    10     double Eps=3;   //区域半径
    11     int MinPts=4;   //密度
    12      
    13     //由于自己到自己的距离是0,所以自己也是自己的neighbor
    14     public Vector<DataObject> getNeighbors(DataObject p,ArrayList<DataObject> objects){
    15         Vector<DataObject> neighbors=new Vector<DataObject>();
    16         Iterator<DataObject> iter=objects.iterator();
    17         while(iter.hasNext()){
    18             DataObject q=iter.next();
    19             double[] arr1=p.getVector();
    20             double[] arr2=q.getVector();
    21             int len=arr1.length;
    22              
    23             if(Global.calEditDist(arr1,arr2,len)<=Eps){      //使用编辑距离
    24 //          if(Global.calEuraDist(arr1, arr2, len)<=Eps){    //使用欧氏距离    
    25 //          if(Global.calCityBlockDist(arr1, arr2, len)<=Eps){   //使用街区距离
    26 //          if(Global.calSinDist(arr1, arr2, len)<=Eps){ //使用向量夹角的正弦
    27                 neighbors.add(q);
    28             }
    29         }
    30         return neighbors;
    31     }
    32      
    33     public int dbscan(ArrayList<DataObject> objects){
    34         int clusterID=0;
    35         boolean AllVisited=false;
    36         while(!AllVisited){
    37             Iterator<DataObject> iter=objects.iterator();
    38             while(iter.hasNext()){
    39                 DataObject p=iter.next();
    40                 if(p.isVisited())
    41                     continue;
    42                 AllVisited=false;
    43                 p.setVisited(true);     //设为visited后就已经确定了它是核心点还是边界点
    44                 Vector<DataObject> neighbors=getNeighbors(p,objects);
    45                 if(neighbors.size()<MinPts){
    46                     if(p.getCid()<=0)
    47                         p.setCid(-1);       //cid初始为0,表示未分类;分类后设置为一个正数;设置为-1表示噪声。
    48                 }else{
    49                     if(p.getCid()<=0){
    50                         clusterID++;
    51                         expandCluster(p,neighbors,clusterID,objects);
    52                     }else{
    53                         int iid=p.getCid();
    54                         expandCluster(p,neighbors,iid,objects);
    55                     }
    56                 }
    57                 AllVisited=true;
    58             }
    59         }
    60         return clusterID;
    61     }
    62  
    63     private void expandCluster(DataObject p, Vector<DataObject> neighbors,
    64             int clusterID,ArrayList<DataObject> objects) {
    65         p.setCid(clusterID);
    66         Iterator<DataObject> iter=neighbors.iterator();
    67         while(iter.hasNext()){
    68             DataObject q=iter.next();
    69             if(!q.isVisited()){
    70                 q.setVisited(true);
    71                 Vector<DataObject> qneighbors=getNeighbors(q,objects);
    72                 if(qneighbors.size()>=MinPts){
    73                     Iterator<DataObject> it=qneighbors.iterator();
    74                     while(it.hasNext()){
    75                         DataObject no=it.next();
    76                         if(no.getCid()<=0)
    77                             no.setCid(clusterID);
    78                     }
    79                 }
    80             }
    81             if(q.getCid()<=0){       //q不是任何簇的成员
    82                 q.setCid(clusterID);
    83             }
    84         }
    85     }
    86  
    87     public static void main(String[] args){
    88         DataSource datasource=new DataSource();
    89         //Eps=3,MinPts=4
    90         datasource.readMatrix(new File("/home/orisun/test/dot.mat"));
    91         datasource.readRLabel(new File("/home/orisun/test/dot.rlabel"));
    92         //Eps=2.5,MinPts=4
    93 //      datasource.readMatrix(new File("/home/orisun/text.normalized.mat"));
    94 //      datasource.readRLabel(new File("/home/orisun/text.rlabel"));
    95         DBScan ds=new DBScan();
    96         int clunum=ds.dbscan(datasource.objects);
    97         datasource.printResult(datasource.objects,clunum);
    98     }
    99 }
    View Code

    C++实现:

    数据结构

     1 #include <vector>
     2 
     3 using namespace std;
     4 
     5 const int DIME_NUM=2;        //数据维度为2,全局常量
     6 
     7 //数据点类型
     8 class DataPoint
     9 {
    10 private:
    11     unsigned long dpID;                //数据点ID
    12     double dimension[DIME_NUM];        //维度数据
    13     long clusterId;                    //所属聚类ID
    14     bool isKey;                        //是否核心对象
    15     bool visited;                    //是否已访问
    16     vector<unsigned long> arrivalPoints;    //领域数据点id列表
    17 public:
    18     DataPoint();                                                    //默认构造函数
    19     DataPoint(unsigned long dpID,double* dimension , bool isKey);    //构造函数
    20 
    21     unsigned long GetDpId();                //GetDpId方法
    22     void SetDpId(unsigned long dpID);        //SetDpId方法
    23     double* GetDimension();                    //GetDimension方法
    24     void SetDimension(double* dimension);    //SetDimension方法
    25     bool IsKey();                            //GetIsKey方法
    26     void SetKey(bool isKey);                //SetKey方法
    27     bool isVisited();                        //GetIsVisited方法
    28     void SetVisited(bool visited);            //SetIsVisited方法
    29     long GetClusterId();                    //GetClusterId方法
    30     void SetClusterId(long classId);        //SetClusterId方法
    31     vector<unsigned long>& GetArrivalPoints();    //GetArrivalPoints方法
    32 };
    View Code

    实现

     1 #include "DataPoint.h"
     2 
     3 //默认构造函数
     4 DataPoint::DataPoint()
     5 {
     6 }
     7 
     8 //构造函数
     9 DataPoint::DataPoint(unsigned long dpID,double* dimension , bool isKey):isKey(isKey),dpID(dpID)
    10 {
    11     //传递每维的维度数据
    12     for(int i=0; i<DIME_NUM;i++)
    13     {
    14         this->dimension[i]=dimension[i];
    15     }
    16 }
    17 
    18 //设置维度数据
    19 void DataPoint::SetDimension(double* dimension)
    20 {
    21     for(int i=0; i<DIME_NUM;i++)
    22     {
    23         this->dimension[i]=dimension[i];
    24     }
    25 }
    26 
    27 //获取维度数据
    28 double* DataPoint::GetDimension()
    29 {
    30     return this->dimension;
    31 }
    32 
    33 //获取是否为核心对象
    34 bool DataPoint::IsKey()
    35 {
    36     return this->isKey;
    37 }
    38 
    39 //设置核心对象标志
    40 void DataPoint::SetKey(bool isKey)
    41 {
    42     this->isKey = isKey;
    43 }
    44 
    45 //获取DpId方法
    46 unsigned long DataPoint::GetDpId()
    47 {
    48     return this->dpID;
    49 }
    50 
    51 //设置DpId方法
    52 void DataPoint::SetDpId(unsigned long dpID)
    53 {
    54     this->dpID = dpID;
    55 }
    56 
    57 //GetIsVisited方法
    58 bool DataPoint::isVisited()
    59 {
    60     return this->visited;
    61 }
    62 
    63 
    64 //SetIsVisited方法
    65 void DataPoint::SetVisited( bool visited )
    66 {
    67     this->visited = visited;
    68 }
    69 
    70 //GetClusterId方法
    71 long DataPoint::GetClusterId()
    72 {
    73     return this->clusterId;
    74 }
    75 
    76 //GetClusterId方法
    77 void DataPoint::SetClusterId( long clusterId )
    78 {
    79     this->clusterId = clusterId;
    80 }
    81 
    82 //GetArrivalPoints方法
    83 vector<unsigned long>& DataPoint::GetArrivalPoints()
    84 {
    85     return arrivalPoints;
    86 }
    View Code

    PYTHON实现:

     1 from matplotlib.pyplot import *  
     2  from collections import defaultdict  
     3  import random  
     4    
     5  #function to calculate distance  
     6  def dist(p1, p2):  
     7    return ((p1[0]-p2[0])**2+ (p1[1]-p2[1])**2)**(0.5)  
     8    
     9  #randomly generate around 100 cartesian coordinates  
    10  all_points=[]  
    11    
    12  for i in range(100):  
    13    randCoord = [random.randint(1,50), random.randint(1,50)]  
    14    if not randCoord in all_points:  
    15      all_points.append(randCoord)  
    16    
    17    
    18  #take radius = 8 and min. points = 8  
    19  E = 8  
    20  minPts = 8  
    21    
    22  #find out the core points  
    23  other_points =[]  
    24  core_points=[]  
    25  plotted_points=[]  
    26  for point in all_points:  
    27    point.append(0) # assign initial level 0  
    28    total = 0  
    29    for otherPoint in all_points:  
    30      distance = dist(otherPoint,point)  
    31      if distance<=E:  
    32        total+=1  
    33    
    34    if total > minPts:  
    35      core_points.append(point)  
    36      plotted_points.append(point)  
    37    else:  
    38      other_points.append(point)  
    39    
    40  #find border points  
    41  border_points=[]  
    42  for core in core_points:  
    43    for other in other_points:  
    44      if dist(core,other)<=E:  
    45        border_points.append(other)  
    46        plotted_points.append(other)  
    47    
    48    
    49  #implement the algorithm  
    50  cluster_label=0  
    51    
    52  for point in core_points:  
    53    if point[2]==0:  
    54      cluster_label+=1  
    55      point[2]=cluster_label  
    56    
    57    for point2 in plotted_points:  
    58      distance = dist(point2,point)  
    59      if point2[2] ==0 and distance<=E:  
    60        print point, point2  
    61        point2[2] =point[2]  
    62    
    63    
    64  #after the points are asssigned correnponding labels, we group them  
    65  cluster_list = defaultdict(lambda: [[],[]])  
    66  for point in plotted_points:  
    67    cluster_list[point[2]][0].append(point[0])  
    68    cluster_list[point[2]][1].append(point[1])  
    69    
    70  markers = ['+','*','.','d','^','v','>','<','p']  
    71    
    72  #plotting the clusters  
    73  i=0  
    74  print cluster_list  
    75  for value in cluster_list:  
    76    cluster= cluster_list[value]  
    77    plot(cluster[0], cluster[1],markers[i])  
    78    i = i%10+1  
    79    
    80  #plot the noise points as well  
    81  noise_points=[]  
    82  for point in all_points:  
    83    if not point in core_points and not point in border_points:  
    84      noise_points.append(point)  
    85  noisex=[]  
    86  noisey=[]  
    87  for point in noise_points:  
    88    noisex.append(point[0])  
    89    noisey.append(point[1])  
    90  plot(noisex, noisey, "x")  
    91    
    92  title(str(len(cluster_list))+" clusters created with E ="+str(E)+" Min Points="+str(minPts)+" total points="+str(len(all_points))+" noise Points = "+ str(len(noise_points)))  
    93  axis((0,60,0,60))  
    94  show()  
    View Code

    参考:http://www.cnblogs.com/zhangchaoyang/articles/2182748.html

       http://www.cnblogs.com/lovell-liu/archive/2011/11/08/2241542.html

         http://blog.sudipk.com.np/2013/02/implementation-of-dbscan-algorithm-for.html

         http://caoyaqiang.diandian.com/post/2012-09-26/40039517485

  • 相关阅读:
    645. 错误的集合『简单』
    1078. Bigram 分词『简单』
    1018. 可被 5 整除的二进制前缀『简单』
    1010. 总持续时间可被 60 整除的歌曲『简单』
    1417. 重新格式化字符串『简单』
    1413. 逐步求和得到正数的最小值『简单』
    1394. 找出数组中的幸运数『简单』
    1374. 生成每种字符都是奇数个的字符串『简单』
    1365. 有多少小于当前数字的数字『简单』
    1360. 日期之间隔几天『简单』
  • 原文地址:https://www.cnblogs.com/sungyouyu/p/3636708.html
Copyright © 2011-2022 走看看