zoukankan      html  css  js  c++  java
  • meanshift聚类的实现

    参见http://blog.csdn.net/u014568921/article/details/45197027

    // meanshift-cluster.cpp : 定义控制台应用程序的入口点。
    //
    
    #include "stdafx.h"
    #include<assert.h>
    #include<time.h>
    #include<cmath>
    #include<cstdio>
    #include<cstdlib>
    #include<iostream>
    #include<vector>
    using namespace std;
    
    // Scalar type used for every point coordinate.
    #define MSTYPE double
    
    // Convergence threshold: a point is considered settled once a single
    // mean-shift step moves it by no more than this distance. Kept as a macro
    // (hoisted out of apply()) so any code that relied on the name still builds.
    #ifndef EPSILON
    #define EPSILON 0.00000001
    #endif
    
    /**
     * Mean-shift mode seeking over a set of equally-dimensioned points.
     *
     * Usage: construct with a kernel bandwidth, fill the data set with
     * generatedata(), then call apply() to obtain, for every input point, the
     * location it converges to. Points that converge to (nearly) the same
     * location belong to the same cluster.
     *
     * The kernel is exp(-d^2 / kernel_bandwidth); i.e. the bandwidth is used
     * directly as the denominator of the exponent and must be positive.
     */
    class meanshift
    {
    private:
    	/** One point of the data set; `data` holds its coordinates. */
    	struct MSData
    	{
    		std::vector<MSTYPE> data;
    
    		/** Creates a point with `d` coordinates, all initialised to zero. */
    		MSData(std::size_t d) : data(d)
    		{
    		}
    	};
    
    	std::vector<MSData> dataset;
    	double kernel_bandwidth;
    
    	/**
    	 * Performs one mean-shift step for `vec`.
    	 *
    	 * @param vec current position of the point being shifted.
    	 * @return the kernel-weighted mean of all data set points as seen from
    	 *         `vec`. If every weight underflows to zero (or the data set is
    	 *         empty) the position is returned unchanged instead of dividing
    	 *         by zero and producing NaN coordinates.
    	 */
    	MSData shiftvec(const MSData &vec) const
    	{
    		const std::size_t dim = vec.data.size();
    		MSData shifted(dim);
    
    		double total_weight = 0;
    		for (std::size_t i = 0; i < dataset.size(); ++i) {
    			// Bind by reference: copying every sample on every step is wasted work.
    			const MSData &sample = dataset[i];
    			const double weight = gaussian_kernel(euclidean_distance(vec, sample));
    			for (std::size_t j = 0; j < dim; ++j) {
    				shifted.data[j] += sample.data[j] * weight;
    			}
    			total_weight += weight;
    		}
    
    		// Written as a negated '>' so that a NaN total is rejected as well.
    		if (!(total_weight > 0)) {
    			return vec;
    		}
    
    		for (std::size_t j = 0; j < dim; ++j) {
    			shifted.data[j] /= total_weight;
    		}
    		return shifted;
    	}
    
    	/** Kernel weight for a point at the given distance: exp(-d^2 / bandwidth). */
    	double gaussian_kernel(double distance) const
    	{
    		return std::exp(-(distance * distance) / kernel_bandwidth);
    	}
    
    	/**
    	 * Euclidean distance between two points.
    	 * Both points must have the same number of coordinates (asserted).
    	 */
    	double euclidean_distance(const MSData &data1, const MSData &data2) const
    	{
    		assert(data1.data.size() == data2.data.size());
    		double sum = 0;
    		for (std::size_t i = 0; i < data1.data.size(); ++i) {
    			const double diff = data1.data[i] - data2.data[i];
    			sum += diff * diff;
    		}
    		return std::sqrt(sum);
    	}
    
    	/** Writes a point to std::cout as "(c0,c1,...)" for any dimension. */
    	static void print_point(const MSData &point)
    	{
    		std::cout << "(";
    		for (std::size_t k = 0; k < point.data.size(); ++k) {
    			if (k != 0) {
    				std::cout << ",";
    			}
    			std::cout << point.data[k];
    		}
    		std::cout << ")";
    	}
    
    public:
    	/**
    	 * @param kernel_bandwidth denominator of the kernel exponent; must be
    	 *        positive (asserted), otherwise the kernel weights are meaningless.
    	 *
    	 * Also seeds the C random number generator from the current time, so the
    	 * data produced by generatedata() differs between runs.
    	 */
    	meanshift(double kernel_bandwidth) :kernel_bandwidth(kernel_bandwidth)
    	{
    		assert(kernel_bandwidth > 0);
    		std::srand(static_cast<unsigned int>(std::time(NULL)));
    	}
    
    	/**
    	 * Runs mean shift until no point moves by more than EPSILON in one step.
    	 *
    	 * Prints the largest shift of every iteration, then one line per point
    	 * showing its original and its converged coordinates.
    	 *
    	 * @return the converged position of every data set point, in data set
    	 *         order (empty if the data set is empty).
    	 */
    	std::vector<MSData> apply() const
    	{
    		const std::size_t count = dataset.size();
    		// Non-zero once a point has converged and no longer needs shifting.
    		std::vector<char> stop_moving(count, 0);
    		std::vector<MSData> shifted_points = dataset;
    
    		double max_shift_distance;
    		do {
    			max_shift_distance = 0;
    			for (std::size_t i = 0; i < count; ++i) {
    				if (stop_moving[i]) {
    					continue;
    				}
    				const MSData point_new = shiftvec(shifted_points[i]);
    				const double shift_distance = euclidean_distance(point_new, shifted_points[i]);
    				if (shift_distance > max_shift_distance) {
    					max_shift_distance = shift_distance;
    				}
    				if (shift_distance <= EPSILON) {
    					stop_moving[i] = 1;
    				}
    				shifted_points[i] = point_new;
    			}
    			std::printf("max_shift_distance: %f\n", max_shift_distance);
    		} while (max_shift_distance > EPSILON);
    
    		// Report every coordinate rather than assuming exactly two dimensions.
    		for (std::size_t i = 0; i < count; ++i) {
    			std::cout << "原始坐标 ";
    			print_point(dataset[i]);
    			std::cout << "   滑动到  ";
    			print_point(shifted_points[i]);
    			std::cout << '\n';
    		}
    		std::cout.flush();
    
    		return shifted_points;
    	}
    
    	/**
    	 * Appends `datanums` random points to the data set.
    	 *
    	 * @param datanums number of points to add; values <= 0 add nothing.
    	 * @param span     one entry per dimension; coordinate j of every new
    	 *                 point is drawn from [0, span[j]).
    	 *
    	 * The dimension must match any points already stored (asserted), since
    	 * the distance computation requires equally sized points.
    	 */
    	void generatedata(int datanums, const std::vector<int> &span)
    	{
    		if (datanums <= 0) {
    			return;
    		}
    		assert(dataset.empty() || dataset.front().data.size() == span.size());
    
    		dataset.reserve(dataset.size() + static_cast<std::size_t>(datanums));
    		for (int i = 0; i < datanums; ++i) {
    			MSData point(span.size());
    			for (std::size_t j = 0; j < span.size(); ++j) {
    				// RAND_MAX + 1.0 keeps the unit sample strictly below 1.
    				point.data[j] = double(std::rand()) / (RAND_MAX + 1.0) * span[j];
    			}
    			dataset.push_back(point);
    		}
    	}
    };
    
    
    // Console entry point: clusters 100 random 2-D points with mean shift and
    // lets meanshift::apply() print where each point ends up.
    int _tmain(int argc, _TCHAR* argv[])
    {
    	// Two dimensions, each coordinate drawn from [0, 20).
    	vector<int> span(2, 20);
    
    	meanshift ms(4);
    	ms.generatedata(100, span);
    	ms.apply();
    	return 0;
    }
    


    结果如下图



    版权声明:

  • 相关阅读:
    方法引用(method reference)
    函数式接口
    Lambda 表达式
    LinkedList 源码分析
    ArrayList 源码分析
    Junit 学习笔记
    Idea 使用 Junit4 进行单元测试
    Java 定时器
    【干货】Mysql的"事件探查器"-之Mysql-Proxy代理实战一(安装部署与实战sql拦截与性能监控)
    python-flask框架web服务接口开发实例
  • 原文地址:https://www.cnblogs.com/walccott/p/4956867.html
Copyright © 2011-2022 走看看