zoukankan      html  css  js  c++  java
  • C语言实现knn

     以后写代码一定要谨慎,提高代码的正确率。 

    /***************************************
     * 1.初始化距离为最大值
     * 2.计算未知样本和每个训练样本的距离为dist
     * 3.得到目前k个最邻近样本中的最大距离maxdist
     * 4.如果dist小于maxdist,则将该训练样本作为k-最近邻样本
     * 5.重复2、3、4,直至未知样本和训练样本的距离都算完
     * 6.统计k个最近邻样本中每个类别出现的次数
     * 7.选择出现频率最大的类别作为未知样本的类别
     * *****************************************/
    
    #include <stdio.h>
    #include <math.h>
    #include <stdlib.h>
    #include <string.h>
    #define MAX 0x7fffffff
    #define K 3
    
    double  cal_dist(int n,double *x,double *y)
    {
        double sum = 0.0;
        int i =0;
        for(i=0;i<n;i++)
        {
            sum += pow((x[i]-y[i]),2.0);
        }
        return sqrt(sum);
    }
    
    /* Exchange the values stored at the two given locations. */
    static void swap_doubles(double *lhs, double *rhs)
    {
        double held = *lhs;

        *lhs = *rhs;
        *rhs = held;
    }

    /*
     * Bubble-sort two parallel rows in ascending order of one of them.
     *
     * array: array[0] and array[1] are rows of `count` doubles that are
     *        permuted together so that column pairs stay intact.
     * count: number of columns in each row.
     * flag:  0 sorts by array[0], 1 sorts by array[1]; any other value
     *        leaves both rows untouched.
     */
    void bubbleSort(double **array,int count,int flag)
    {
        int key_row = flag;        /* row whose values drive the ordering */
        int paired_row = 1 - flag; /* row that is dragged along */
        int remaining;
        int pos;

        if (flag != 0 && flag != 1) {
            return;
        }

        /* After each pass the largest key of the unsorted prefix is in place. */
        for (remaining = count; remaining > 0; remaining--) {
            for (pos = 0; pos + 1 < remaining; pos++) {
                if (array[key_row][pos] > array[key_row][pos + 1]) {
                    swap_doubles(&array[key_row][pos], &array[key_row][pos + 1]);
                    swap_doubles(&array[paired_row][pos], &array[paired_row][pos + 1]);
                }
            }
        }
    }
    int main()
    {
        int n,m;
        FILE *fp;
        fp = fopen("/data.txt","r");
        fscanf(fp,"N=%d,D=%d",&n,&m);
        printf("N=%d,D=%d
    ",n,m);
        double  **array;
        array = (double **)malloc(n*sizeof(double));
        array[0] = (double *)malloc(n*m*sizeof(double));
        int h,j = 0,i =0;
        for(i=1;i<n;i++)
        {
            array[i] = array[i-1] + m;                            
        }
        for(i=0;i<n;i++)
        {
            for(j=0;j<m;j++)
            {
                fscanf(fp,"%lf",&array[i][j]);                                            
            }                                
        }
        double **temp;
        temp = (double **)malloc(2*sizeof(double));
        temp[0] = (double *)malloc(2*K*sizeof(double));
        for(i=1;i<2;i++)
        {
            temp[i] = temp[i-1] + K;
        }
        for(i=0;i<2;i++)
        {
            for(j=0;j<K;j++)
            {
                temp[i][j] = MAX*0.1;
            }
        }
        double *testdata;
        double max_dist = 0.0;
        double distance = 0.0;
        double tmp = 0.0;
        testdata=(double *)malloc((m-1)*sizeof(double));
        printf("input test data containing %d numbers:
    ",m-1);
        for(i=0;i<(m-1);i++)
        {
                fscanf(fp,"%lf",&testdata[i]);
        }
        close(fp);
        while(1)
        {
            for(i=0;i<K;i++)
            {
                if(K > n) break;
                temp[0][i] = cal_dist(n,testdata,array[i]);
                temp[1][i] = array[i][m-1];
            }
            for(i=0;i<K;i++)
            {
                printf("%4lf,%4lf
    ",temp[0][i],temp[1][i]);
            }
            printf("
    ");
            bubbleSort(temp,K,0);
            max_dist = temp[0][K-1];
            for(i=K;i<n;i++)
            {
                distance = cal_dist(n,testdata,array[i]);
                if(max_dist > distance)
                {
                    for(j=0;j<K;j++)
                    {
                        if(distance < temp[0][j])
                        {
                            for(h=K-1;h>j;h--)
                            {
                                temp[0][h] = temp[0][h-1];
                                temp[1][h] = temp[1][h-1];
                            }
                        }
                        temp[0][j] = distance;
                        temp[1][j] = array[i][m-1];
                    }
                }
                max_dist = temp[0][K-1];
            }
            bubbleSort(temp,K,1);
            break;
        }
    
    
        int value_label = 0;
        int count = 0;
        int flag = 0;
        for(i=0;i<K-1;i++)
        {
            if(temp[1][i] != temp[1][i+1])
            {
                if(flag > count)
                {
                    flag = count;
                    value_label = temp[1][i];
                    count =1;
                }
            }
            else
            {
                count ++;
            }
        }
        if(count > flag)
        {
            value_label = temp[1][K-1];
            flag = count;
        }
        printf("Predict message is %d
    ",value_label);
        return 0;
    }
  • 相关阅读:
    win10 访问远程文件夹 此共享需要过时的SMB1协议 你不能访问此共享文件夹
    Navicat 1142 SELECT command denied to user 'sx'@'xxx' for table 'user'
    MySQL 密码参数配置与修改 validate_password
    MySQL 命令行下更好的显示查询结果
    MySQL 数据库的存储结构
    MySQL实验 内连接优化order by+limit 以及添加索引再次改进
    MySQL实验 子查询优化双参数limit
    MySQL 索引结构 hash 有序数组
    MySQL 树形索引结构 B树 B+树
    hbase2.1.9 centos7 完全分布式 搭建随记
  • 原文地址:https://www.cnblogs.com/chenyang920/p/7398226.html
Copyright © 2011-2022 走看看