zoukankan      html  css  js  c++  java
  • KNN算法的實現

    转自 小橋流水

    Knn.h

    #pragma once

    /*
     * k-nearest-neighbour classifier over a table of doubles.
     *
     * The training table has m rows and n columns. Columns 0 .. n-2 are the
     * features; column n-1 holds the class label of the row.
     *
     * Ownership: the training table is borrowed from the caller and never freed
     * by this class, but its feature columns are standardised IN PLACE during
     * construction. The per-feature mean and standard deviation arrays are
     * allocated by the class and released by the destructor.
     *
     * The class is non-copyable: an implicit copy would share the two owned
     * arrays and both destructors would delete[] them.
     */
    class Knn
    {
    private:
     double** trainingDataset;    // borrowed: m row pointers, each to n doubles
     double* arithmeticMean;      // owned: n - 1 per-feature means
     double* standardDeviation;   // owned: n - 1 per-feature standard deviations
     int m, n;                    // m = number of rows, n = columns per row (features + label)

     // Copying is disabled (declared private, intentionally left undefined) to
     // prevent a double delete[] of arithmeticMean / standardDeviation.
     Knn(const Knn&);
     Knn& operator=(const Knn&);

     // Standardises the n - 1 feature entries of one row in place.
     void RescaleDistance(double* row);
     // Standardises every training row in place.
     void RescaleTrainingDataset();
     // Allocates and fills arithmeticMean from the training rows.
     void ComputeArithmeticMean();
     // Allocates and fills standardDeviation; needs arithmeticMean to be ready.
     void ComputeStandardDeviation();

     // Euclidean distance between the feature parts of two rows.
     double Distance(double* x, double* y);
    public:
     // Borrows trainingDataset (m rows of n doubles) and standardises it in place.
     Knn(double** trainingDataset, int m, int n);
     // Frees the statistics arrays; the training table is left to the caller.
     ~Knn();
     // Predicts the label of `test` (n doubles) from its k nearest training rows.
     // `test` is modified: features are standardised and test[n-1] receives the label.
     double Vote(double* test, int k);
    };

    Knn.cpp

    #include "Knn.h"
    #include <cmath>
    #include <map>
    #include <stdexcept>

    using namespace std;

    /*
     * Builds a classifier on top of the caller's training table.
     *
     * trainingDataset : m row pointers, each to n doubles; the pointer is stored,
     *                   not copied, so the table must outlive this object.
     * m               : number of training rows.
     * n               : number of columns per row.
     *
     * The statistics are computed first because the final step overwrites the
     * raw feature values of the table with their standardised form.
     */
    Knn::Knn(double** trainingDataset, int m, int n)
     : trainingDataset(trainingDataset), m(m), n(n)
    {
     ComputeArithmeticMean();
     ComputeStandardDeviation();
     RescaleTrainingDataset();
    }

    void Knn::ComputeArithmeticMean()
    {
     arithmeticMean = new double[n - 1];

     double sum;

     for(int i = 0; i < n - 1; i++)
     {
      sum = 0;
      for(int j = 0; j < m; j++)
      {
       sum += trainingDataset[j][i];
      }

      arithmeticMean[i] = sum / n;
     }
    }

    void Knn::ComputeStandardDeviation()
    {
     standardDeviation = new double[n - 1];

     double sum, temp;

     for(int i = 0; i < n - 1; i++)
     {
      sum = 0;
      for(int j = 0; j < m; j++)
      {
       temp = trainingDataset[j][i] - arithmeticMean[i];
       sum += temp * temp;
      }

      standardDeviation[i] = sqrt(sum / n);
     }
    }

    void Knn::RescaleDistance(double* row)
    {
     for(int i = 0; i < n - 1; i++)
     {
      row[i] = (row[i] - arithmeticMean[i]) / standardDeviation[i];
     }
    }

    void Knn::RescaleTrainingDataset()
    {
     for(int i = 0; i < m; i++)
     {
      RescaleDistance(trainingDataset[i]);
     }
    }

    /*
     * Releases the two statistics arrays allocated by the Compute*() helpers.
     * trainingDataset is not deleted here.
     */
    Knn::~Knn()
    {
     delete[] standardDeviation;
     delete[] arithmeticMean;
    }

    /*
     * Euclidean distance between x and y over their first n - 1 entries.
     * Entry n - 1 (the label slot) does not take part.
     */
    double Knn::Distance(double* x, double* y)
    {
     const int featureCount = n - 1;
     double squaredTotal = 0;

     for(int col = 0; col < featureCount; ++col)
     {
      const double gap = x[col] - y[col];
      squaredTotal += gap * gap;
     }

     return sqrt(squaredTotal);
    }

    double Knn::Vote(double* test, int k)
    {
     RescaleDistance(test);

     double distance;

     map<int, double>::iterator max;

     map<int, double> mins;

     for(int i = 0; i < m; i++)
     {
      distance = Distance(test, trainingDataset[i]);
      if(mins.size() < k)
       mins.insert(map<int, double>::value_type(i, distance));
      else
      {
       max = mins.begin();
       for(map<int, double>::iterator it = mins.begin(); it != mins.end(); it++)
       {
        if(it->second > max->second)
         max = it;
       }
       if(distance < max->second)
       {
        mins.erase(max);
        mins.insert(map<int, double>::value_type(i, distance));
       }
      }
     }

     map<double, int> votes;
     double temp;

     for(map<int, double>::iterator it = mins.begin(); it != mins.end(); it++)
     {
      temp = trainingDataset[it->first][n-1];
      map<double, int>::iterator voteIt = votes.find(temp);
      if(voteIt != votes.end())
       voteIt->second ++;
      else
       votes.insert(map<double, int>::value_type(temp, 1));
     }

     map<double, int>::iterator maxVote = votes.begin();

     for(map<double, int>::iterator it = votes.begin(); it != votes.end(); it++)
     {
      if(it->second > maxVote->second)
       maxVote = it;
     }

     test[n-1] = maxVote->first;

     return maxVote->first;
    }

    main.cpp

    #include <iostream>
    #include "Knn.h"

    using namespace std;

    int main(const int& argc, const char* argv[])
    {
     double** train = new double* [14];
     for(int i = 0; i < 14; i ++)
      train[i] = new double[5];
     double trainArray[14][5] = 
     {
      {0, 0, 0, 0, 0},
      {0, 0, 0, 1, 0},
      {1, 0, 0, 0, 1},
      {2, 1, 0, 0, 1},
      {2, 2, 1, 0, 1},
      {2, 2, 1, 1, 0},
      {1, 2, 1, 1, 1},
      {0, 1, 0, 0, 0},
      {0, 2, 1, 0, 1},
      {2, 1, 1, 0, 1},
      {0, 1, 1, 1, 1},
      {1, 1, 0, 1, 1},
      {1, 0, 1, 0, 1},
      {2, 1, 0, 1, 0}
     };

     for(int i = 0; i < 14; i ++)
      for(int j = 0; j < 5; j ++)
       train[i][j] = trainArray[i][j];

     Knn knn(train, 14, 5);

     double test[5] = {2, 2, 0, 1, 0};
     cout<<knn.Vote(test, 3)<<endl;

     for(int i = 0; i < 14; i ++)
      delete[] train[i];

     delete[] train;

     return 0;
    }



  • 相关阅读:
    HDU1875——畅通工程再续(最小生成树:Kruskal算法)
    CodeForces114E——Double Happiness(素数二次筛选)
    POJ3083——Children of the Candy Corn(DFS+BFS)
    POJ3687——Labeling Balls(反向建图+拓扑排序)
    SDUT2157——Greatest Number(STL二分查找)
    UVA548——Tree(中后序建树+DFS)
    HDU1312——Red and Black(DFS)
    生活碎碎念
    SQL基础四(例子)
    Linux系统中的一些重要的目录
  • 原文地址:https://www.cnblogs.com/xiangshancuizhu/p/2229195.html
Copyright © 2011-2022 走看看