zoukankan      html  css  js  c++  java
  • k-means 算法的 C++ 实现（改为面向对象）

    画的类图如下（图片未能保留）：

    程序分为 kMeans 和 pointClass 两个类：kMeans 有两个成员变量，分别是数据点总个数和最终聚类个数；pointClass 用结构体 point 存放数据点和聚类中心。

    具体代码如下:

    kMeans.h

    // Driver for the k-means run: stores the problem size and runs the
    // clustering loop.
    class kMeans
    {
    public:
        // Asks on standard input for the number of points and clusters.
        kMeans();
        // Runs the clustering and reports or exports the result.
        void k_means();

    protected:
        int numOfPoint;   // number of data points to load
        int numOfCenter;  // number of clusters (centers) to form
    };
    View Code

    kMeans.cpp

    #include "kMeans.h"
    #include "point.h"
    #include <cstdlib>
    #include <iostream>
    using namespace std;
    
    // Reads the problem size (number of data points, number of clusters) from
    // standard input.
    //
    // pointClass::InitCenter draws numOfCenter distinct indices with
    // rand() % numOfPoint. That is undefined for numOfPoint == 0 and never
    // terminates when numOfCenter > numOfPoint, so the input is validated here:
    // both values must be positive and there may not be more clusters than
    // points. On a failed read or invalid values the program exits with a
    // non-zero status.
    kMeans::kMeans()
    {
        cout << "分别输入数据点和最终聚类的个数: ";
        if (!(cin >> numOfPoint >> numOfCenter)) {
            cerr << "输入无效: 需要两个整数" << endl;
            exit(1);
        }
        if (numOfPoint <= 0 || numOfCenter <= 0 || numOfCenter > numOfPoint) {
            cerr << "输入无效: 要求 0 < 聚类个数 <= 数据点个数" << endl;
            exit(1);
        }
    }
    
    void kMeans::k_means()
    {
        pointClass pc(numOfPoint, numOfCenter);
        pc.InitCenter(numOfPoint, numOfCenter);
        bool b = true;
        while (b){
            pc.setPoint(numOfPoint, numOfCenter);
            if (pc.getError()){
                break;
            }
            pc.getNewCenter(numOfPoint, numOfCenter);
            b = pc.IsEnd(numOfCenter);
            pc.resetCenterOld(numOfCenter);
        }
        if (b) {
            cout << "聚类操作无法完成!!" << endl;
        }else{
            pc.ExportData(numOfPoint, numOfCenter);
        }
    }
    View Code

    point.h

    // A sample with four numeric features x1..x4.
    // `flag` has two uses: on a data point it holds the index of the assigned
    // cluster; on a center it counts members while a new center is computed.
    struct point
    {
        double x1;
        double x2;
        double x3;
        double x4;
        int flag;
    };
    
    // Holds the data points plus the "old" (current) and "new" (recomputed)
    // cluster-center arrays for the k-means iteration. All three arrays are
    // allocated with new[] in the constructor and released with delete[] in
    // the destructor.
    class pointClass{
    protected:
        double errorOld, errorNew;                     // error bookkeeping across rounds
        point *pList, *centerListNew, *centerListOld;  // owned new[] arrays
        double GetDistance(point point1, point point2);
        bool GetExist(int xm, int CenterIndex[], int n);
    public:
        // Allocates the arrays and loads numOfPoint points from the data file.
        pointClass (int numOfPoint, int numOfCenter);
        // Picks numOfCenter distinct random data points as centerListOld.
        void InitCenter (int numOfPoint, int numOfCenter);
        // Intended to assign each point to its nearest center and update errorNew.
        void setPoint (int numOfPoint, int numOfCenter);
        // True when errorOld is non-zero and errorNew >= errorOld.
        bool getError();
        // Recomputes centerListNew from the points assigned to each center.
        void getNewCenter (int numOfPoint, int numOfCenter);
        // True while some center moved by more than 1 (i.e. NOT converged).
        bool IsEnd (int numOfCenter);
        // Prints the error and writes the clustering result file.
        void ExportData (int numOfPoint, int numOfCenter);
        // Copies centerListNew into centerListOld.
        void resetCenterOld (int numOfCenter);
        ~pointClass ();
    private:
        // The destructor delete[]s the owned arrays, so a compiler-generated
        // copy would share them and free them twice. Copying is disabled
        // (declared but intentionally never defined).
        pointClass (const pointClass&);
        pointClass& operator= (const pointClass&);
    };
    View Code

    point.cpp

    #include <iostream>
    #include <fstream>
    #include <cmath>
    #include <ctime>
    #include "point.h"
    using namespace std;
    
    // Euclidean distance between two points over the four features x1..x4.
    // Squares are written as plain products and the root as sqrt: std::pow is
    // a general transcendental routine, and this function runs in the
    // innermost assignment loop.
    double pointClass::GetDistance(point point1, point point2)
    {
        const double d1 = point1.x1 - point2.x1;
        const double d2 = point1.x2 - point2.x2;
        const double d3 = point1.x3 - point2.x3;
        const double d4 = point1.x4 - point2.x4;
        return sqrt(d1 * d1 + d2 * d2 + d3 * d3 + d4 * d4);
    }
    
    // Linear search: true if xm appears among the first n entries of
    // CentIndex, false otherwise (including when n <= 0).
    bool pointClass::GetExist(int xm, int CentIndex[], int n)
    {
        for (int k = 0; k < n; ++k) {
            if (CentIndex[k] == xm) {
                return true;
            }
        }
        return false;
    }
    
    // Allocates numOfPoint data points and numOfCenter old/new centers,
    // loads the points from the data file, and zeroes both center arrays and
    // the error values.
    //
    // Each record is read as four numbers x1 x2 x3 x4 with operator>>.
    // The program exits with a non-zero status if the file cannot be opened
    // or holds fewer than numOfPoint complete records.
    pointClass::pointClass(int numOfPoint, int numOfCenter)
    {
        pList = new point[numOfPoint];
        centerListOld = new point[numOfCenter];
        centerListNew = new point[numOfCenter];

        // The backslash must be escaped. "\I" is not a valid escape sequence;
        // compilers that accept it typically drop the backslash, which opens
        // the drive-relative path "D:IrisData.txt" instead.
        ifstream ifile("D:\\IrisData.txt");
        if (!ifile.is_open()){
            cerr << "file" << endl;
            exit(1);  // a failure must not report success
        }
        for (int i = 0; i < numOfPoint; i++){
            ifile >> pList[i].x1 >> pList[i].x2 >> pList[i].x3 >> pList[i].x4;
            if (!ifile){
                cerr << "数据文件中的数据点不足 " << numOfPoint << " 个" << endl;
                exit(1);
            }
            pList[i].flag = 0;
        }
        ifile.close();

        errorNew = 0;
        errorOld = 0;
        const point zero = {0, 0, 0, 0, 0};
        for (int i = 0; i < numOfCenter; i++){
            centerListOld[i] = zero;
            centerListNew[i] = zero;
        }
    }
    
    void pointClass::InitCenter(int numOfPoint, int numOfCenter)
    {
        int xm, i;
        int *CenterIndex = new int[numOfCenter];
        srand((unsigned)time(0));
        for (i = 0; i < numOfCenter; i++){
            do {
                xm = rand() % numOfPoint;
            } while (GetExist(xm, CenterIndex, i));
            CenterIndex[i] = xm;
        }
    
        for (i = 0; i < numOfCenter; i++){
            centerListOld[i] = pList[CenterIndex[i]];
        }
    }
    
    
    void pointClass::setPoint(int numOfPoint, int numOfCenter)
    {
        errorNew = 0;
        for (int i = 0; i < numOfPoint; i++){
            int flagi = 0;
            double distance = GetDistance(pList[i], centerListOld[0]);
            for (int j = 1; j < numOfCenter; j++){
                double tmp = GetDistance(pList[i], centerListOld[j]);
                if (tmp < distance){
                    tmp = distance;
                    flagi = j;
                }
            }
            pList[i].flag = flagi;
            errorNew = GetDistance(pList[i], centerListOld[flagi]);
        }
    }
    
    // Reports a non-improving error: true when a previous error has been
    // recorded (errorOld != 0) and the current error is not smaller than it.
    // NOTE(review): in the visible code errorOld is only ever assigned 0 (in
    // the constructor), so this always returns false as written.
    bool pointClass::getError()
    {
        return errorOld != 0 && errorNew >= errorOld;
    }
    
    void pointClass::getNewCenter(int numOfPoint, int numOfCenter)
    {
        for (int i = 0; i < numOfCenter; i++){
            centerListNew[i].x1 = 0;
            centerListNew[i].x2 = 0;
            centerListNew[i].x3 = 0;
            centerListNew[i].x4 = 0;
            centerListNew[i].flag = 0;
        }
        for (int i = 0; i < numOfPoint; i++){
            centerListNew[pList[i].flag].x1 += pList[i].x1;
            centerListNew[pList[i].flag].x2 += pList[i].x2;
            centerListNew[pList[i].flag].x3 += pList[i].x3;
            centerListNew[pList[i].flag].x4 += pList[i].x4;
            centerListNew[pList[i].flag].flag++;
        }
        for (int i = 0; i < numOfCenter; i++){
            centerListNew[i].x1 = centerListNew[i].x1 / centerListNew[i].flag;
            centerListNew[i].x2 = centerListNew[i].x2 / centerListNew[i].flag;
            centerListNew[i].x3 = centerListNew[i].x3 / centerListNew[i].flag;
            centerListNew[i].x4 = centerListNew[i].x4 / centerListNew[i].flag;
            centerListNew[i].flag = 0;
        }
    }
    
    void pointClass::resetCenterOld(int numOfCenter)
    {
        for (int i = 0; i < numOfCenter; i++){
            centerListOld[i] = centerListNew[i];
        }
    }
    
    // Despite its name, returns true while the clustering has NOT converged,
    // i.e. when at least one center moved by more than 1 between
    // centerListOld and centerListNew. Returns false once no center did.
    bool pointClass::IsEnd(int numOfCenter)
    {
        for (int k = 0; k < numOfCenter; ++k) {
            const bool moved = GetDistance(centerListNew[k], centerListOld[k]) > 1;
            if (moved) {
                return true;
            }
        }
        return false;
    }
    
    void pointClass::ExportData(int numOfPoint, int numOfCenter)
    {
        ofstream ofile("D:\kMeansResult.txt");
        cout << "本次误差是:" << errorNew << endl;
        ofile << "本次误差是:" << errorNew << endl;
        for (int j = 0; j < numOfCenter; j++){
            ofile << "" << j+1 <<"类:" << endl;
            for (int i = 0; i < numOfPoint; i++){
                if (pList[i].flag == j){
                    ofile << pList[i].x1 << " " << pList[i].x2 << " " << pList[i].x3 << " " <<pList[i].x4 << endl;
                }
            }
        }
    }
    
    
    // Releases the three arrays the constructor allocated with new[].
    pointClass::~pointClass()
    {
        delete[] centerListNew;
        delete[] centerListOld;
        delete[] pList;
    }
    View Code

    主函数main.cpp

    #include <iostream>
    #include "kMeans.h"
    using namespace std;
    
    int main()
    {
        kMeans km;
        km.k_means();
        return 0;
    }
    View Code
  • 相关阅读:
    install git on ubuntu
    deploy uwsgi with niginx on ubuntu
    ubuntu下部署solr
    solr relevent project
    20100722
    [Programming Visual C++]Ex05cCScrollView Revisited
    iter_test
    交友类节目
    install scrapy on windows
    20100703
  • 原文地址:https://www.cnblogs.com/usa007lhy/p/3316985.html
Copyright © 2011-2022 走看看