zoukankan      html  css  js  c++  java
  • knn-matlab实现


    close all;
    clc;

    %%算法实现
    %step1、初始化训练集、测试集、K值
    %创建一个三维矩阵,二维表示同一类下的二维坐标点,第三维表示类别

    trainData1=[0 0;0.1 0.3;0.2 0.1;0.2 0.2];%第一类训练数据
    trainData2=[1 0;1.1 0.3;1.2 0.1;1.2 0.2];%第二类训练数据
    trainData3=[0 1;0.1 1.3;0.2 1.1;0.2 1.2];%第三类训练数据

    trainData(:,:,1)=trainData1;%设置第一类测试数据
    trainData(:,:,2)=trainData2;%设置第二类测试数据
    trainData(:,:,3)=trainData3;%设置第三类测试数据

    trainDim=size(trainData);%获取训练集的维数
    %空间三维矩阵 4列 2列 3排
    % trainDim =
    %
    % 4 2 3

    testData=[1.6 0.3];%设置1个测试点

    K=7;

    %%分别计算测试集中各个点与每个训练集中的点的欧氏距离
    %把测试点扩展成矩阵
    testData_rep=repmat(testData,4,1)

    %初始化值
    % A = repmat(10,3,2)
    % A = 3×2
    %
    % 10 10
    % 10 10
    % 10 10
    %

    % testData_rep =
    %
    % 1.6000 0.3000
    % 1.6000 0.3000
    % 1.6000 0.3000
    % 1.6000 0.3000


    %设置三个二维矩阵存放测试集与测试点的扩展矩阵的差值平方


    for i=1:trainDim(3)
    diff1=(trainData(:,:,1)-testData_rep).^2;
    diff2=(trainData(:,:,2)-testData_rep).^2;
    diff3=(trainData(:,:,3)-testData_rep).^2;
    end

    %设置三个一维数组存放欧式距离
    distance1=(diff1(:,1)+diff1(:,2)).^0.5 ;
    %取diff1的横坐标,取diff1的纵坐标 计算第一类点集和测试点的距离
    distance2=(diff2(:,1)+diff2(:,2)).^0.5;
    distance3=(diff3(:,1)+diff3(:,2)).^0.5;

    %将三个一维数组合成一个二维矩阵
    temp=[distance1 distance2 distance3] %所有的距离都列在一起 size(temp)=4*3 三个样本点集到测试点的距离组成的矩阵
    %将这个二维矩阵转换为一维数组
    distance=reshape(temp,1,3*4);
    %对距离进行排序
    distance_sort=sort(distance) %把12哥距离排序 从小到大


    num1=0;%第一类出现的次数
    num2=0;%第二类出现的次数
    num3=0;%第三类出现的次数
    sum=0;%sum1,sum2,sum3的和
    for i=1:K %取所有距离前K个最小的
    for j=1:4%每一类训练集只有4哥样本点

    if distance1(j)==distance_sort(i) %依次把每一个样本点到测试点的距离依次和所有的距离从小到大进行比较
    %i不动,j变四次 如果能找到相同的,说明这些点都是比较近的,如果没有出现distance的点,那说明这个点太远了
    num1=num1+1; %如果发现
    % disp('*****')
    distance1(j);
    num1 %最近的优先级最高,所以最先比较,但是一个点不能说明问题,要把最近的那一类点都total进来,才能发现属于哪一类
    end
    if distance2(j)==distance_sort(i)
    num2=num2+1;
    % disp('////')
    distance2(j); %最小的点在第二类 依次累加num_i的个数
    num2;
    end
    if distance3(j)==distance_sort(i)
    num3=num3+1;
    %disp('-----')
    distance3(j);
    sum3
    end
    end
    sum=num1+num2+num3;
    if sum>=K
    break;
    end
    end

    class=[num1 num2 num3];

    %classname=find(class(1,:)==max(class));
    max(class)%最大的那一类就是第二类,sum3等于0 sum1=3
    class(1,:)

    classname=find(class(1,:)==max(class))%最大的那一类就是第二类,所以把最近的最多的 归为第二类

    fprintf('测试点(%f %f)属于第%d类',testData(1),testData(2),classname);%最重要的一句话

    %%使用绘图将训练集点和测试集点绘画出来
    figure(1);
    hold on;
    for i=1:4
    plot(trainData1(i,1),trainData1(i,2),'*');
    plot(trainData2(i,1),trainData2(i,2),'o');
    plot(trainData3(i,1),trainData3(i,2),'>');
    end


    plot(testData(1),testData(2),'x');

    text(0.1,0.1,'第一类');
    text(1.1,0.1,'第二类');
    text(0.1,1,'第三类');

  • 相关阅读:
    JavaScript Validator 报错
    JSP项目_Web路径_磁盘物理路径
    TreaponseHeader
    TrequestHeader
    HTML学习笔记1
    SQL文摘:DATE_TRUNC: A SQL Timestamp Function You Can Count On
    SQL文摘:Writing Subqueries in SQL
    Python文摘:Requests (Adavanced Usage)
    Python文摘:Requests
    Python文摘:More About Unicode in Python 2 and 3
  • 原文地址:https://www.cnblogs.com/china520/p/11615624.html
Copyright © 2011-2022 走看看