zoukankan      html  css  js  c++  java
  • 机器学习技法笔记:Homework #8 kNN&RBF&k-Means相关习题

    原文地址:https://www.jianshu.com/p/1db700f866ee

    问题描述

    图1 问题描述
    图2 11-18
    图3 19-20

    程序实现

    # kNN_RBFN.py
    # coding:utf-8
    
    import numpy as np
    import matplotlib.pyplot as plt
    
    
    def ReadData(dataFile):
        """Load a whitespace-separated numeric text file into a 2-D array.

        Each non-blank line becomes one row; every field on the line is parsed
        with ``float``.

        Args:
            dataFile: path of the text file to read.

        Returns:
            ``np.ndarray`` with one row per non-blank line. An input without
            any data line yields an empty array.

        Raises:
            ValueError: if a field cannot be parsed as a float.
        """
        rows = []
        with open(dataFile, 'r') as f:
            # Iterate the file lazily instead of materialising every line first.
            for line in f:
                fields = line.split()
                # A blank (or whitespace-only) line, e.g. a trailing newline at
                # the end of the file, would otherwise add an empty row and make
                # the list ragged.
                if not fields:
                    continue
                rows.append([float(field) for field in fields])
        return np.array(rows)
    
    
    def sign(n):
        """Return +1 when ``n`` is non-negative (zero included), otherwise -1."""
        return 1 if n >= 0 else -1
    
    
    def kNN(k,trainArray,dataX):
        """Classify every row of ``dataX`` by a vote of its ``k`` nearest training rows.

        Args:
            k: number of neighbours that take part in the vote.
            trainArray: 2-D array whose last column is the label and whose
                remaining columns are the features.
            dataX: 2-D array of query points (features only).

        Returns:
            1-D float array holding ``sign`` of the summed neighbour labels for
            each query row.
        """
        features = trainArray[:, :-1]
        labels = trainArray[:, -1]

        def vote(n):
            # Squared Euclidean distance from query n to every training row.
            sqDist = np.sum((features - dataX[n, :]) ** 2, axis=1)
            nearest = np.argsort(sqDist, axis=0)[:k]
            tally = 0.0
            for idx in nearest:
                tally += labels[idx]
            return sign(tally)

        return np.array([vote(n) for n in range(dataX.shape[0])], dtype=float)
    
    
    def GetZeroOneError(predY,dataY):
        """Return the fraction of entries where ``predY`` differs from ``dataY``.

        The mismatch count is divided by the length of ``dataY``'s first axis.
        """
        mismatches = (predY != dataY)
        total = dataY.shape[0]
        return mismatches.sum() / total
    
    
    def plot_bar_chart(X,Y,nameX,nameY,saveName):
        """Draw a bar chart of ``Y`` against ``X`` and save it to ``saveName``.

        Every bar is annotated with its value rounded to four decimals, and the
        y-axis is fixed to [0, 1].

        Args:
            X: sequence of bar positions (also used as the x tick locations).
            Y: sequence of bar heights, same length as ``X``.
            nameX: x-axis label.
            nameY: y-axis label; the title is ``nameY + " versus " + nameX``.
            saveName: output path handed to ``plt.savefig``.
        """
        plt.figure(figsize=(10,6))
        # The positions and heights are passed positionally: the old ``left=``
        # keyword is rejected by current Matplotlib releases (the parameter is
        # now called ``x``), while the positional form works on old and new
        # versions alike.
        plt.bar(X,Y,width=0.8,align="center",yerr=0.000001)
        for (c,w) in zip(X,Y):
            # Place the value label slightly above the top of the bar.
            plt.text(c,w*1.03,str(round(w,4)))
        plt.xlabel(nameX)
        plt.ylabel(nameY)
        plt.xlim(X[0]-1,X[-1]+1)
        plt.xticks(X)
        plt.ylim(0,1)
        plt.title(nameY+" versus "+nameX)
        plt.savefig(saveName)
    
    
    def RBFNetwork(k,gamma,trainArray,dataX):
        """Classify ``dataX`` with a Gaussian RBF vote over the ``k`` nearest training rows.

        For a query x the prediction is
        ``sign(sum_m y_m * exp(-gamma * ||x - x_m||^2))`` where the sum runs over
        the ``k`` training rows closest to x. Passing ``k`` equal to the number
        of training rows gives the full uniform RBF network.

        The previous version only used the Gaussian values as a sort key and
        summed the labels unweighted, so with all rows selected the prediction
        did not depend on the query or on ``gamma`` at all.

        Args:
            k: number of nearest training rows that vote.
            gamma: coefficient of the Gaussian kernel.
            trainArray: 2-D array whose last column is the label and whose
                remaining columns are the features.
            dataX: 2-D array of query points (features only).

        Returns:
            1-D float array of +1.0 / -1.0 predictions, one per query row
            (a score of exactly zero maps to +1.0).
        """
        trainX = trainArray[:, :-1]
        trainY = trainArray[:, -1]
        num_data = dataX.shape[0]
        predY = np.zeros((num_data,))
        for n in range(num_data):
            sqDist = np.sum((trainX - dataX[n, :]) ** 2, axis=1)
            # Smallest distance first == largest Gaussian weight first.
            nearest = np.argsort(sqDist, kind="stable")[:k]
            if nearest.size:
                # Subtracting the smallest selected distance rescales every
                # weight by the same positive factor, so the sign of the score
                # is unchanged, but the weights no longer all underflow to zero
                # when gamma * distance is large.
                shifted = sqDist[nearest] - sqDist[nearest[0]]
                score = np.dot(np.exp(-gamma * shifted), trainY[nearest])
            else:
                score = 0.0
            predY[n] = 1.0 if score >= 0 else -1.0
        return predY
    
    
    if __name__=="__main__":

        # Load the labelled train / test sets; the last column is the label.
        dataArray=ReadData("hw8_train.dat")
        testArray=ReadData("hw8_test.dat")
        trainX,trainY=dataArray[:,:-1],dataArray[:,-1]
        testX,testY=testArray[:,:-1],testArray[:,-1]

        # k-nearest-neighbour: in-sample and out-of-sample 0/1 error for each k.
        k_list=[1,3,5,7,9]
        ein_list=[GetZeroOneError(kNN(k,dataArray,trainX),trainY) for k in k_list]
        eout_list=[GetZeroOneError(kNN(k,dataArray,testX),testY) for k in k_list]

        # 12
        plot_bar_chart(k_list,ein_list,nameX="k",nameY="Ein(gk-nbor)",saveName="12.png")

        # 14 (label fixed: it used to read "gk-bor", inconsistent with chart 12)
        plot_bar_chart(k_list,eout_list,nameX='k',nameY="Eout(gk-nbor)",saveName="14.png")


        # Uniform RBF network: every training row votes, gamma = 10**exponent.
        gamma_list=[-3,-1,0,1,2]
        num_train=dataArray.shape[0]
        ein_list=[GetZeroOneError(RBFNetwork(num_train,10**gamma,dataArray,trainX),trainY) for gamma in gamma_list]
        eout_list=[GetZeroOneError(RBFNetwork(num_train,10**gamma,dataArray,testX),testY) for gamma in gamma_list]

        # 16
        plot_bar_chart(X=gamma_list,Y=ein_list,nameX="log10(gamma)",nameY="Ein(guniform)",saveName="16.png")

        # 18
        plot_bar_chart(X=gamma_list,Y=eout_list,nameX="log10(gamma)",nameY="Eout(guniform)",saveName="18.png")
    
    
    # kMeans.py
    # coding:utf-8
    
    from numpy import random
    from kNN_RBFN import *
    
    
    def kMeans(t,k,dataArray):
        """Run k-means (alternating assignment / mean update) on ``dataArray``.

        Fixes relative to the previous version:
          * each centre is now the mean of the points assigned to it (it used
            to be a running ``(centre + point) / 2`` half-average);
          * the initial centres are ``k`` distinct rows whenever
            ``k <= len(dataArray)`` (sampling used to allow duplicates);
          * the stopping test starts from infinity instead of 1000000, so data
            with a large error no longer stops after the first pass.

        Args:
            t: seed passed to ``random.seed`` before the initial centres are drawn.
            k: number of clusters.
            dataArray: 2-D array with one point per row.

        Returns:
            ``(centres, clusters)`` where ``centres`` is a ``(k, dim)`` float
            array and ``clusters`` maps a cluster index to the list of rows
            assigned to it. Clusters that received no point have no entry and
            keep their previous centre.
        """
        num_data = dataArray.shape[0]
        random.seed(t)
        # Sample without replacement so two clusters never start on the same
        # row; fall back to replacement only when there are fewer rows than
        # clusters, which keeps that call working as before.
        centreIDList = random.choice(num_data, size=k, replace=k > num_data)
        nowCentreArray = np.array(dataArray[centreIDList, :], dtype=float)
        clusters = {}
        prevEin = np.inf
        while True:
            # Assignment step: nearest centre (squared distance) for every point.
            sqDist = np.sum(
                (dataArray[:, np.newaxis, :] - nowCentreArray[np.newaxis, :, :]) ** 2,
                axis=2,
            )
            labels = np.argmin(sqDist, axis=1)

            # Update step: move every non-empty cluster to the mean of its members.
            newCentreArray = np.array(nowCentreArray)
            newClusters = {}
            for m in range(k):
                members = dataArray[labels == m]
                if members.shape[0] == 0:
                    continue
                newCentreArray[m] = members.mean(axis=0)
                newClusters[m] = list(members)

            # Mean squared distance of every point to its own centre.
            nowEin = np.sum((dataArray - newCentreArray[labels]) ** 2) / num_data
            nowCentreArray, clusters = newCentreArray, newClusters
            # Stop as soon as the error no longer strictly decreases.
            if not nowEin < prevEin:
                break
            prevEin = nowEin
        return nowCentreArray, clusters
    
    
    def GetEin(nowCentreArray,dict):
        """Return the k-means error: mean squared distance of every point to its centre.

        The previous version added up the per-cluster averages, which weights a
        one-point cluster as heavily as a thousand-point cluster and is not the
        average over all points.

        Args:
            nowCentreArray: ``(k, dim)`` array of cluster centres.
            dict: mapping from cluster index to the list of points assigned to
                that cluster; indices without an entry are skipped. (The name
                shadows the builtin but is kept for existing callers.)

        Returns:
            Total squared distance divided by the total number of points, or 0
            when no point is assigned at all.
        """
        total = 0.0
        count = 0
        for i in range(nowCentreArray.shape[0]):
            # Skip clusters that are missing or hold no point.
            if i not in dict or len(dict[i]) == 0:
                continue
            data = np.array(dict[i])
            total += np.sum((data - nowCentreArray[i]) ** 2)
            count += data.shape[0]
        if count == 0:
            return 0
        return total / count
    
    
    def plot_bar_chart(X,Y,nameX,nameY,saveName):
        """Draw a bar chart of ``Y`` against ``X`` and save it to ``saveName``.

        Every bar is annotated with its value rounded to four decimals. The
        y-axis range is left to Matplotlib's autoscaling.

        Args:
            X: sequence of bar positions (also used as the x tick locations).
            Y: sequence of bar heights, same length as ``X``.
            nameX: x-axis label.
            nameY: y-axis label; the title is ``nameY + " versus " + nameX``.
            saveName: output path handed to ``plt.savefig``.
        """
        plt.figure(figsize=(10,6))
        # The positions and heights are passed positionally: the old ``left=``
        # keyword is rejected by current Matplotlib releases (the parameter is
        # now called ``x``), while the positional form works on old and new
        # versions alike.
        plt.bar(X,Y,width=0.8,align="center",yerr=0.000001)
        for (c,w) in zip(X,Y):
            # Place the value label slightly above the top of the bar.
            plt.text(c,w*1.03,str(round(w,4)))
        plt.xlabel(nameX)
        plt.ylabel(nameY)
        plt.xlim(X[0]-1,X[-1]+1)
        plt.xticks(X)
        plt.title(nameY+" versus "+nameX)
        plt.savefig(saveName)
    
    
    if __name__=="__main__":

        # Average the k-means error over 500 seeded runs for each cluster count
        # and chart the averages.
        unlabeled=ReadData("hw8_nolabel_train.dat")
        num_trials=500
        cluster_counts=[2,4,6,8,10]
        mean_eins=[]
        for k in cluster_counts:
            total=0
            for seed in range(num_trials):
                # The seed doubles as the experiment index.
                total+=GetEin(*kMeans(seed,k,unlabeled))
            mean_eins.append(total/num_trials)

        plot_bar_chart(cluster_counts,mean_eins,nameX="k",nameY="the average Ein over 500 experiments",saveName="20.png")
    
    

    运行结果

    图4 12结果
    图5 14结果
    图6 16结果
    图7 18结果
    图8 20结果

  • 相关阅读:
    HTML和CSS之HTML(记录一2015.3.30)
    jquery学习记录三(表单选择器)
    jquery学习记录四(操作DOM元素)
    jquery学习记录二(过滤性选择器)
    jquery学习记录一(基础选择器)
    聚集索引和非聚集索引
    git命令
    4,gps信号与地图匹配算法
    3,gps定位原理及格式
    2,地图数据分析-地图数据转换成导航引擎数据
  • 原文地址:https://www.cnblogs.com/cherrychenlee/p/10803832.html
Copyright © 2011-2022 走看看