zoukankan      html  css  js  c++  java
  • KNN in C++

    Pseudo Code of KNN

    We can implement a KNN model by following the below steps:

    1. Load the data
    2. Initialise the value of k
    3. To get the predicted class for a test point:
      1. Iterate over all training data points and calculate the distance between the test point and each row of training data. Here we use Euclidean distance as the distance metric, since it is the most popular choice; other metrics such as Manhattan, Chebyshev or cosine distance can also be used.
      2. Sort the calculated distances in ascending order.
      3. Take the top k rows from the sorted array.
      4. Find the most frequent class among these rows.
      5. Return that class as the predicted class.

    Iris Data Set

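    把数据作为string类型处理,进行string和double类型转换。(All data is read as strings and converted between string and double as needed. Note: the knn() function below prints the k nearest training rows — distance and class label — for each test row; it does not perform the majority vote of step 4.)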

    #include <iostream>
    #include <string>
    #include <fstream>
    #include <sstream>
    #include <numeric>
    #include <functional>
    #include <vector>
    #include <algorithm>
    #include <cmath>
    #include <map>

    /// Manhattan (L1) distance between two equally sized numeric vectors.
    ///
    /// Each element is converted to double before subtracting, so unsigned
    /// element types do not wrap around (and std::abs is never called on an
    /// unsigned value, which is ambiguous).
    /// @param inst1 first point
    /// @param inst2 second point
    /// @return the sum of |inst1[i] - inst2[i]|, or -1 (after printing a
    ///         message) when the vectors differ in size.
    template <class T1, class T2>
    double ManhattanDistance(const std::vector<T1> &inst1, const std::vector<T2> &inst2) {
        if(inst1.size() != inst2.size()) {
            std::cout<<"the size of the vectors is not the same ";
            return -1;
        }
        double distance = 0.0;
        for(size_t i=0; i<inst1.size(); ++i) {
            distance += std::abs(static_cast<double>(inst1[i]) - static_cast<double>(inst2[i]));
        }
        return distance;
    }

    /// Euclidean (L2) distance between two equally sized numeric vectors.
    ///
    /// Elements are converted to double before subtracting, so unsigned element
    /// types do not wrap around. Squares are computed as diff*diff rather than
    /// through pow().
    /// @param inst1 first point
    /// @param inst2 second point
    /// @return sqrt(sum((inst1[i] - inst2[i])^2)), or -1 (after printing a
    ///         message) when the vectors differ in size. A real distance is
    ///         never negative, so callers can treat a negative result as an error.
    template <class DataType1, class DataType2>
    double EuclideanDistance(const std::vector<DataType1> &inst1, const std::vector<DataType2> &inst2) {
        if(inst1.size() != inst2.size()) {
            std::cout<<"the size of the vectors is not the same ";
            return -1;
        }
        double sum = 0.0;
        for(size_t i=0; i<inst1.size(); ++i) {
            const double diff = static_cast<double>(inst1[i]) - static_cast<double>(inst2[i]);
            sum += diff * diff;
        }
        return std::sqrt(sum);
    }

    /// Parses each string in [beg, end) as a double and appends it to `vdouble`
    /// (existing contents of `vdouble` are kept).
    ///
    /// Leading whitespace is skipped and parsing stops at the first character
    /// that cannot continue a number. A string with no leading number (including
    /// an empty string) yields 0.0.
    void vstr2vdouble(std::vector<std::string>::const_iterator beg, std::vector<std::string>::const_iterator end, std::vector<double> &vdouble) {
        for(std::vector<std::string>::const_iterator it=beg; it!=end; ++it) {
            // Must be initialised: when extraction fails before conversion starts
            // (e.g. an empty string hits EOF), operator>> leaves the target untouched.
            double d = 0.0;
            std::istringstream ss(*it);
            ss >> d;
            vdouble.push_back(d);
        }
    }

    void knn(std::vector<std::vector<std::string> > &trainset, std::vector<std::string> &testdata, int &k) {
        std::vector<double> testitem;
        vstr2vdouble(testdata.begin(), testdata.end(), testitem);
        std::multimap<std::string, std::string> mmap;

        for(size_t i=0;i<trainset.size();++i) {
            std::vector<double> trainitem;
            vstr2vdouble(trainset[i].begin(), trainset[i].end()-1, trainitem);
            double distance=EuclideanDistance(testitem, trainitem);
            std::string strdis;
            std::stringstream ss;
            ss<<distance;
            ss>>strdis;
            mmap.insert(std::pair<std::string, std::string>(strdis, trainset[i].back()));
        }
        size_t i=0;
        for(std::multimap<std::string, std::string>::const_iterator it=mmap.begin(); i<k; ++i,++it) {
            std::cout<<it->first<<" "<<it->second<<" ";
        }
    }

    /// Reads a comma-separated text file into `lines_feat`, one row of string
    /// fields per non-blank line.
    ///
    /// Commas are treated as whitespace, so fields are split on commas, spaces,
    /// tabs and stray '\r' (CRLF files). Lines with no fields are skipped rather
    /// than ending the read, so blank lines anywhere in the file are harmless and
    /// no empty rows are produced.
    /// Only DataType = std::string is supported, because each row is built from
    /// string tokens.
    /// @param filename path of the file to read
    /// @param lines_feat output; cleared first. Left empty (and a message is
    ///        printed to stderr) if the file cannot be opened.
    template <class DataType>
    void ReadDataFromFile(const std::string &filename, std::vector<std::vector<DataType> > &lines_feat) {
        lines_feat.clear();

        std::ifstream vm_info(filename.c_str());
        if(!vm_info) {
            std::cerr<<"cannot open file: "<<filename<<'\n';
            return;
        }

        std::string line, var;
        // Loop on getline's result instead of !eof(), which is only set after a failed read.
        while(std::getline(vm_info, line)) {
            std::replace(line.begin(), line.end(), ',', ' ');
            std::istringstream stringin(line);
            std::vector<std::string> row;
            while(stringin >> var) {
                row.push_back(var);
            }
            if(!row.empty())
                lines_feat.push_back(row);
        }
    }

    /// Dumps a 2-D vector to stdout.
    ///
    /// Prints a header line with the row count, then every element followed by a
    /// space, an extra space after each row, and finally a closing marker.
    /// @param vv rows to print; DataType must be streamable to std::ostream
    template <class DataType>
    void Display2DVector(std::vector<std::vector<DataType> > &vv) {
        std::cout << "the total rows of 2d vector_data: " << vv.size() << std::endl;

        typedef typename std::vector<std::vector<DataType> >::const_iterator RowIter;
        for (RowIter row = vv.begin(); row != vv.end(); ++row) {
            for (size_t col = 0; col != row->size(); ++col)
                std::cout << (*row)[col] << " ";
            std::cout << " ";  // row separator
        }

        std::cout << "--------the end of the Display2DVector()-------- ";
    }

    int main() {
        std::string trainpath="Iris.data", testpath="knntest.data";
        std::vector<std::vector<std::string> > knn_data, test_data;

        ReadDataFromFile(trainpath, knn_data);
        ReadDataFromFile(testpath, test_data);

        Display2DVector(test_data);

        int k=3;
        for(size_t i=0;i<test_data.size();++i) {
            knn(knn_data, test_data[i], k);
        }

        return 0;
    }

  • 相关阅读:
    java.lang.IllegalAccessError: tried to access method org.apache.poi.util.POILogger.log from class org.apache.poi.openxml4j.opc.ZipPackage
    相同域名不同端口的两个应用,cookie名字、路径都相同的情况下,后面cookie会覆盖前面cookie吗
    power designer 连接mysql提示“connection test failed”
    疑问:Spring 中构造器、init-method、@PostConstruct、afterPropertiesSet 孰先孰后,自动注入发生时间
    intelj idea 创建聚合项目(典型web项目,包括子项目util、dao、service)
    Mysql启动时提示:Another MySQL daemon already running with the same unix socket.
    MySql中的varchar长度究竟是字节还是字符
    百度echarts使用--y轴label数字太长难以全部显示
    记录项目中用的laypage分页代码
    Ubuntu16.04下安装Cmake-3.8.2并为其配置环境变量
  • 原文地址:https://www.cnblogs.com/donggongdechen/p/10430993.html
Copyright © 2011-2022 走看看