zoukankan      html  css  js  c++  java
  • 【StatLearn】统计学习中knn算法实验(2)

    接着统计学习中knn算法实验(1)的内容

    Problem:

    1. Explore the data before classification using summary statistics or visualization
    2. Pre-process the data (such as denoising, normalization, feature selection, …) 
    3. Try other distance metrics or distance-based voting
    4. Try other dimensionality reduction methods
    5. How to set the k value, if not using cross validation? Verify your idea
    问题:
    1. 在对数据分类之前使用对数据进行可视化处理
    2. 预处理数据(去噪,归一化,数据选择)
    3. 在knn算法中使用不同的距离计算方法
    4. 使用其他的降维算法
    5. 如何在不使用交叉验证的情况下设置k值

    使用Parallel coordinates plot做数据可视化,首先对数据进行归一化处理,数据的动态范围控制在[0,1]。注意归一化的处理针对的是每一个fearture。




    通过对图的仔细观察,我们挑选出重叠度比较低的feature来进行fearture selection,feature selection实际上是对数据挑选出更易区分的类型作为下一步分类算法的数据。我们挑选出feature序号为(1)、(2)、(5)、(6)、(7)、(10)的feature。个人认为,feature selection是一种简单而粗暴的降维和去噪的操作,但是可能效果会很好。 

    根据上一步的操作,从Parallel coordinates上可以看出,序号为(1)、(2)、(5)、(6)、(7)、(10)这几个feature比较适合作为classify的feature。我们选取以上几个feature作knn,得到的结果如下:

    当K=1 的时候,Accuracy达到了85.38%,并且相比于简单的使用knn或者PCA+knn的方式,Normalization、Featrure Selection的方法使得准确率大大提升。我们也可以使用不同的feature搭配,通过实验得到更好的结果。


    MaxAccuracy= 0.8834 when k=17 (Normalization+FeartureSelection+KNN)

     

    试验中,我们使用了两种不同的Feature Selection 策略,选用较少fearture的策略对分类的准确率还是有影响的,对于那些从平行坐标看出的不那么好的fearture,对分类还是有一定的帮助的。
    在较小的k值下,Feature Selection的结果要比直接采用全部Feature的结果要好。这也体现了在相对纯净的数据下,较小的k值能够获得较好的结果,这和直观感觉出来的一致。
    我们再尝试对数据进行进一步的预处理操作,比如denoising。
    数据去噪的方法利用对Trainning数据进行一个去处最大最小边缘值的操作,我们认为,对于一个合适的feature,它的数据应该处于一个合理的范围中,过大或者过小的数据都将是异常的。

    Denoising的代码如下:

     

    function[DNData]=DataDenoising(InputData,KillRange)
    DNData=InputData;
    %MedianData=median(DNData);
    for i=2:size(InputData,2)
       [temp,DNIndex]=sort(DNData(:,i));
       DNData=DNData(DNIndex(1+KillRange:end-KillRange),:);
    end




     

    采用LLE作为降维的手段,通过和以上的几种方案作对比,如下:


     

    MaxAccuracy= 0.9376 when K=23 (LLE dimensionality reduction to 2)

    关于LLE算法,参见这篇论文

     

    • Nonlinear dimensionality reduction by locally linear embedding.Sam Roweis & Lawrence Saul.Science, v.290 no.5500 , Dec.22, 2000. pp.2323--2326.
    以及项目主页:


    源代码:

    StatLearnProj.m

    clear;
    data=load('wine.data.txt');
    %calc 5-folder knn
    Accuracy=[];
    for i=1:5
        Test=data(i:5:end,:);
        TestData=Test(:,2:end);
        TestLabel=Test(:,1);
        Trainning=setdiff(data,Test,'rows');
        TrainningData=Trainning(:,2:end);
        TrainningLabel=Trainning(:,1);
        Accuracy=cat(1,Accuracy,CalcAccuracy(TestData,TestLabel,TrainningData,TrainningLabel));
    end
    AccuracyKNN=mean(Accuracy,1);
    
    %calc PCA
    Accuracy=[];
    %PCA
    [Coeff,Score,Latent]=princomp(data(:,2:end));
    dataPCA=[data(:,1),Score(:,1:6)];
    Latent
    for i=1:5
        Test=dataPCA(i:5:end,:);
        TestData=Test(:,2:end);
        TestLabel=Test(:,1);
        Trainning=setdiff(dataPCA,Test,'rows');
        TrainningData=Trainning(:,2:end);
        TrainningLabel=Trainning(:,1);
        Accuracy=cat(1,Accuracy,CalcAccuracy(TestData,TestLabel,TrainningData,TrainningLabel));
    end
    AccuracyPCA=mean(Accuracy,1);
    BarData=[AccuracyKNN;AccuracyPCA];
    bar(1:2:51,BarData');
    
    [D,I]=sort(AccuracyKNN,'descend');
    D(1)
    I(1)
    [D,I]=sort(AccuracyPCA,'descend');
    D(1)
    I(1)
    
    %pre-processing data
    %Normalization
    labs1={'1)Alcohol','(2)Malic acid','3)Ash','4)Alcalinity of ash'};
    labs2={'5)Magnesium','6)Total phenols','7)Flavanoids','8)Nonflavanoid phenols'};
    labs3={'9)Proanthocyanins','10)Color intensity','11)Hue','12)OD280/OD315','13)Proline'};
    uniData=[];
    for i=2:size(data,2)
        uniData=cat(2,uniData,(data(:,i)-min(data(:,i)))/(max(data(:,i))-min(data(:,i))));
    end
    figure();
    parallelcoords(uniData(:,1:4),'group',data(:,1),'labels',labs1);
    figure();
    parallelcoords(uniData(:,5:8),'group',data(:,1),'labels',labs2);
    figure();
    parallelcoords(uniData(:,9:13),'group',data(:,1),'labels',labs3);
    
    %denoising
    
    %Normalization && Feature Selection
    uniData=[data(:,1),uniData];
    %Normalization all feature
    
    for i=1:5
        Test=uniData(i:5:end,:);
        TestData=Test(:,2:end);
        TestLabel=Test(:,1);
        Trainning=setdiff(uniData,Test,'rows');
        TrainningData=Trainning(:,2:end);
        TrainningLabel=Trainning(:,1);
        Accuracy=cat(1,Accuracy,CalcAccuracy(TestData,TestLabel,TrainningData,TrainningLabel));
    end
    AccuracyNorm=mean(Accuracy,1);
    
    %KNN PCA Normalization
    BarData=[AccuracyKNN;AccuracyPCA;AccuracyNorm];
    bar(1:2:51,BarData');
    
    %Normalization& FS 1 2 5 6 7 10 we select 1 2 5 6 7 10 feature 
    FSData=uniData(:,[1 2 3 6 7 8 11]);
    size(FSData)
    for i=1:5
        Test=FSData(i:5:end,:);
        Trainning=setdiff(FSData,Test,'rows');
        TestData=Test(:,2:end);
        TestLabel=Test(:,1);
        TrainningData=Trainning(:,2:end);
        TrainningLabel=Trainning(:,1);
        Accuracy=cat(1,Accuracy,CalcAccuracy(TestData,TestLabel,TrainningData,TrainningLabel));
    end
    AccuracyNormFS1=mean(Accuracy,1);
    
    %Normalization& FS 1 6 7 
    FSData=uniData(:,[1 2 7 8]);
    for i=1:5
        Test=FSData(i:5:end,:);
        Trainning=setdiff(FSData,Test,'rows');
        TestData=Test(:,2:end);
        TestLabel=Test(:,1); 
        TrainningData=Trainning(:,2:end);
        TrainningLabel=Trainning(:,1);
        Accuracy=cat(1,Accuracy,CalcAccuracy(TestData,TestLabel,TrainningData,TrainningLabel));
    end
    AccuracyNormFS2=mean(Accuracy,1);
    figure();
    BarData=[AccuracyNorm;AccuracyNormFS1;AccuracyNormFS2];
    bar(1:2:51,BarData');
    
    [D,I]=sort(AccuracyNorm,'descend');
    D(1)
    I(1)
    [D,I]=sort(AccuracyNormFS1,'descend');
    D(1)
    I(1)
    [D,I]=sort(AccuracyNormFS2,'descend');
    D(1)
    I(1)
    %denoiding
    %Normalization& FS 1 6 7 
    FSData=uniData(:,[1 2 7 8]);
    for i=1:5
        Test=FSData(i:5:end,:);
        Trainning=setdiff(FSData,Test,'rows');
        Trainning=DataDenoising(Trainning,2);
        TestData=Test(:,2:end);
        TestLabel=Test(:,1);     
        TrainningData=Trainning(:,2:end);
        TrainningLabel=Trainning(:,1);
        Accuracy=cat(1,Accuracy,CalcAccuracy(TestData,TestLabel,TrainningData,TrainningLabel));
    end
    AccuracyNormFSDN=mean(Accuracy,1);
    figure();
    hold on
    plot(1:2:51,AccuracyNormFSDN);
    plot(1:2:51,AccuracyNormFS2,'r');
    
    %other distance metrics
    
    Dist='cityblock';
    for i=1:5
        Test=uniData(i:5:end,:);
        TestData=Test(:,2:end);
        TestLabel=Test(:,1);
        Trainning=setdiff(uniData,Test,'rows');
        TrainningData=Trainning(:,2:end);
        TrainningLabel=Trainning(:,1);
        Accuracy=cat(1,Accuracy,CalcAccuracyPlus(TestData,TestLabel,TrainningData,TrainningLabel,Dist));
    end
    AccuracyNormCity=mean(Accuracy,1);
    
    BarData=[AccuracyNorm;AccuracyNormCity];
    figure();
    bar(1:2:51,BarData');
    
    [D,I]=sort(AccuracyNormCity,'descend');
    D(1)
    I(1)
    
    %denoising
    FSData=uniData(:,[1 2 7 8]);
    Dist='cityblock';
    for i=1:5
        Test=FSData(i:5:end,:);
        TestData=Test(:,2:end);
        TestLabel=Test(:,1);
        Trainning=setdiff(FSData,Test,'rows');
        Trainning=DataDenoising(Trainning,3);
        TrainningData=Trainning(:,2:end);
        TrainningLabel=Trainning(:,1);
        Accuracy=cat(1,Accuracy,CalcAccuracyPlus(TestData,TestLabel,TrainningData,TrainningLabel,Dist));
    end
    AccuracyNormCityDN=mean(Accuracy,1);
    figure();
    hold on
    plot(1:2:51,AccuracyNormCityDN);
    plot(1:2:51,AccuracyNormCity,'r');
    
    %call lle
    
    data=load('wine.data.txt');
    uniData=[];
    for i=2:size(data,2)
        uniData=cat(2,uniData,(data(:,i)-min(data(:,i)))/(max(data(:,i))-min(data(:,i))));
    end
    uniData=[data(:,1),uniData];
    LLEData=lle(uniData(:,2:end)',5,2);
    %size(LLEData)
    LLEData=LLEData';
    LLEData=[data(:,1),LLEData];
    
    Accuracy=[];
    for i=1:5
        Test=LLEData(i:5:end,:);
        TestData=Test(:,2:end);
        TestLabel=Test(:,1);
        Trainning=setdiff(LLEData,Test,'rows');
        Trainning=DataDenoising(Trainning,2);
        TrainningData=Trainning(:,2:end);
        TrainningLabel=Trainning(:,1);
        Accuracy=cat(1,Accuracy,CalcAccuracyPlus(TestData,TestLabel,TrainningData,TrainningLabel,'cityblock'));
    end
    AccuracyLLE=mean(Accuracy,1);
    [D,I]=sort(AccuracyLLE,'descend');
    D(1)
    I(1)
    
    BarData=[AccuracyNorm;AccuracyNormFS2;AccuracyNormFSDN;AccuracyLLE];
    figure();
    bar(1:2:51,BarData');
    
    save('ProcessingData.mat');
    
        

    CalcAccuracy.m

    function Accuracy=CalcAccuracy(TestData,TestLabel,TrainningData,TrainningLabel)
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %calculate the accuracy of classify
    %TestData:M*D matrix D stand for dimension,M is sample
    %TrainningData:T*D matrix
    %TestLabel:Label of TestData
    %TrainningLabel:Label of Trainning Data
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    CompareResult=[];
    for k=1:2:51
        ClassResult=knnclassify(TestData,TrainningData,TrainningLabel,k);
        CompareResult=cat(2,CompareResult,(ClassResult==TestLabel));
    end
    SumCompareResult=sum(CompareResult,1);
    Accuracy=SumCompareResult/length(CompareResult(:,1));

    CalcAccuracyPlus.m

    function Accuracy=CalcAccuracyPlus(TestData,TestLabel,TrainningData,TrainningLabel,Dist)
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %just as CalcAccuracy,but add distance metrics
    %calculate the accuracy of classify
    %TestData:M*D matrix D stand for dimension,M is sample
    %TrainningData:T*D matrix
    %TestLabel:Label of TestData
    %TrainningLabel:Label of Trainning Data
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    CompareResult=[];
    for k=1:2:51
        ClassResult=knnclassify(TestData,TrainningData,TrainningLabel,k,Dist);
        CompareResult=cat(2,CompareResult,(ClassResult==TestLabel));
    end
    SumCompareResult=sum(CompareResult,1);
    Accuracy=SumCompareResult/length(CompareResult(:,1));




  • 相关阅读:
    bat脚本%cd%和%~dp0的区别
    java测试程序运行时间
    != 的注意事项
    [转载] iptables 防火墙设置
    .NET 创建 WebService
    [转载] 学会使用Web Service上(服务器端访问)~~~
    cygwin 安装 apt-cyg
    在Element节点上进行Xpath
    Element节点输出到System.out
    [转载] 使用StAX解析xml
  • 原文地址:https://www.cnblogs.com/pangblog/p/3402651.html
Copyright © 2011-2022 走看看