zoukankan      html  css  js  c++  java
  • 手写识别——KNN

    #include <iostream>
    #include<map>  
    #include<vector>  
    #include<stdio.h>  
    #include<cmath>  
    #include<cstdlib>  
    #include<algorithm>  
    #include<fstream>
     
    using namespace std;
     
    typedef char tLabel;  
    typedef double tData;  
    typedef pair<int,double>  PAIR;  
    const int colLen = 2;//导入新的数据集时只需要修改行列参数  
    const int rowLen = 6;  
    ifstream fin;  
    ofstream fout;
     
    class KNN
    {
    private:
            tData dataSet[rowLen][colLen];    //用数组定义样本集
            tLabel labels[rowLen];
            tData testData[colLen];  
            int k;  
            map<int,double> map_index_dis;  
            map<tLabel,int> map_label_freq;  
            double get_distance(tData *d1,tData *d2);    //计算两两样本间距离函数
    public:
            KNN(int k);  //构造函数
      
            void get_all_distance();  
              
            void get_max_freq_label();  
      
            struct CmpByValue  
            {  
                bool operator() (const PAIR& lhs,const PAIR& rhs)  
                {  
                    return lhs.second < rhs.second;  
                }  
            };      
    };
     
    KNN::KNN(int k)
    {
        this->k = k;
        fin.open("movie_data.txt");//导入新的数据集时只需修改文件名
        if(!fin)
        {
            cout<<"can not open the file data.txt"<<endl;  
            exit(1);    
        }
        for(int i = 0; i < rowLen; i++)
        {
            for(int j = 0;j <colLen; j++)
            {
                fin>>dataSet[i][j];
            }
            fin>>labels[i];
        }
        
        cout<<"please input the test data :"<<endl;  
        //输入测试数据   
        for(int i=0;i<colLen;i++)  
            cin>>testData[i];
    }
     
    double KNN:: get_distance(tData *d1,tData *d2)  
    {  
        double sum = 0;  
        for(int i=0;i<colLen;i++)  
        {  
            sum += pow( (d1[i]-d2[i]) , 2 );  
        }  
      
    //  cout<<"the sum is = "<<sum<<endl;  
        return sqrt(sum);  
    }
     
    //计算测试样本与训练集中每个样本的距离   
    void KNN:: get_all_distance()  
    {  
        double distance;  
        int i;  
        for(i=0;i<rowLen;i++)  
        {  
            distance = get_distance(dataSet[i],testData);  
            //<key,value> => <i,distance>  
            map_index_dis[i] = distance;  
        }  
        //遍历map,打印各个序号和距离
        map<int,double>::const_iterator it = map_index_dis.begin();  
        while(it!=map_index_dis.end())  
        {  
            cout<<"index = "<<it->first<<" distance = "<<it->second<<endl;  
            it++;  
        }  
    }  
       
    //在k值设定的情况下,计算测试数据属于哪个lable,并输出   
    void KNN:: get_max_freq_label()  
    {  
        //将map_index_dis转换为vec_index_dis  
        vector<PAIR> vec_index_dis( map_index_dis.begin(),map_index_dis.end() );   
        //对vec_index_dis进行从低到高排序,以获得最近距离数据
        
        sort(vec_index_dis.begin(),vec_index_dis.end(),CmpByValue());  
      
        for(int i=0;i<k;i++)  
        {  
            cout<<"the index = "<<vec_index_dis[i].first<<" the distance = "<<vec_index_dis[i].second<<" the label = "<<labels[vec_index_dis[i].first]<<" the coordinate ( "<<dataSet[ vec_index_dis[i].first ][0]<<","<<dataSet[ vec_index_dis[i].first ][1]<<" )"<<endl;  
            //calculate the count of each label  
            map_label_freq[ labels[ vec_index_dis[i].first ]  ]++;  
        }  
      
        map<tLabel,int>::const_iterator map_it = map_label_freq.begin();  
        tLabel label;  
        int max_freq = 0;  
        //find the most frequent label  
        while( map_it != map_label_freq.end() )  
        {  
            if( map_it->second > max_freq )  
            {  
                max_freq = map_it->second;  
                label = map_it->first;  
            }  
            map_it++;  
        }  
        cout<<"The test data belongs to the "<<label<<" label"<<endl;  
    }  
      
    int main()  
    {  
        int k ;  
        cout<<"please input the k value : "<<endl;  
        cin>>k;  
        KNN knn(k);  
        knn.get_all_distance();  
        knn.get_max_freq_label();    
        return 0;  
    }  
     
     
  • 相关阅读:
    tftp服务器
    iw工具的使用
    六、【ioctl】应用程序和驱动程序中的ioctl
    位反转现象(Bit Flip)
    openwrt有线网卡的停用与开启
    寒假小记
    ARMLinux汇编到ADS汇编转换需要注意的问题
    c function pointer example
    (转)解决mysql“Access denied for user 'root'@'localhost'”
    c语言 面向对象的栈
  • 原文地址:https://www.cnblogs.com/tianwenjing123-456/p/14941521.html
Copyright © 2011-2022 走看看