zoukankan      html  css  js  c++  java
  • 机器学习基石笔记:Homework #2 Decision Stump相关习题

    原文地址:http://www.jianshu.com/p/4bc01760ac20

    问题描述

    图1 第16-18题题目
    图2 第19-20题题目

    程序实现

    17-18

    # coding: utf-8
    
    import numpy as np
    import matplotlib.pyplot as plt
    
    
    def sign(n):
        """Return +1 when n is strictly positive and -1 otherwise (so sign(0) == -1)."""
        return 1 if n > 0 else -1
    
    def gen_data(num_data=20, noise_rate=0.2):
        """Generate a noisy 1-D dataset for the decision-stump experiment.

        x is drawn uniformly from [-1, 1). The clean label is sign(x) with
        sign(0) == -1. The labels at the first ``num_data * noise_rate``
        positions of a random permutation of the indices are flipped to
        simulate noise.

        Args:
            num_data: number of samples (default 20).
            noise_rate: fraction of labels to flip (default 0.2).

        Returns:
            ndarray of shape (num_data, 2): column 0 is x, column 1 is y in {-1, +1}.
        """
        data_X = np.random.uniform(-1, 1, (num_data, 1))  # [-1, 1)
        # Clean labels; sign(0) is treated as -1.
        data_Y = np.where(data_X > 0, 1.0, -1.0)
        idArray = np.random.permutation(num_data)
        # Same rule as "position i < num_data * noise_rate" in a random order.
        flip_ids = idArray[np.arange(num_data) < num_data * noise_rate]
        data_Y[flip_ids] *= -1
        return np.concatenate((data_X, data_Y), axis=1)
    
    def decision_stump(dataArray):
        """Find the 1-D decision stump h(x) = s * sign(x - theta) with the fewest training errors.

        Candidate thresholds are the midpoints between consecutive sorted x
        values, plus (x_max + 1) / 2 above the largest point. Both directions
        s in {-1, +1} are tried, and sign(0) is treated as -1. Ties are broken
        uniformly at random with np.random.randint.

        Args:
            dataArray: ndarray of shape (N, 2); column 0 is x, column 1 is y.

        Returns:
            (minErrors, min_s, min_theta): number of misclassified points of the
            chosen stump, its direction and its threshold.

        Raises:
            ValueError: if dataArray has no rows.
        """
        num_data = dataArray.shape[0]
        if num_data == 0:
            raise ValueError("dataArray must contain at least one sample")
        # Stable sort by x, like list.sort in the original implementation.
        data = dataArray[np.argsort(dataArray[:, 0], kind="stable")]
        x = data[:, 0]
        y = data[:, 1]
        # Midpoints between neighbours, plus one threshold above the largest x.
        thetas = np.append((x[:-1] + x[1:]) / 2, (x[-1] + 1.0) / 2)
        # Error counts are bounded by num_data (not by a fixed 20).
        minErrors = num_data
        min_s_theta_list = []
        for s in [-1.0, 1.0]:
            for theta in thetas:
                pred = s * np.where(x - theta > 0, 1.0, -1.0)
                errors = int(np.sum(pred != y))
                if errors < minErrors:
                    minErrors = errors
                    min_s_theta_list = [(s, float(theta))]
                elif errors == minErrors:
                    min_s_theta_list.append((s, float(theta)))
        pick = np.random.randint(low=0, high=len(min_s_theta_list))
        min_s, min_theta = min_s_theta_list[pick]
        return minErrors, min_s, min_theta
    
    def computeEinEout(minErrors, min_s, min_theta, num_data=20):
        """Convert a stump's training error count into (Ein, Eout).

        Args:
            minErrors: number of misclassified training points.
            min_s: stump direction (+1 or -1).
            min_theta: stump threshold.
            num_data: training-set size used to normalise Ein (default 20).

        Returns:
            (Ein, Eout), where Eout is the closed form
            0.5 + 0.3 * s * (|theta| - 1). That formula presumably assumes the
            gen_data setting (x ~ U[-1, 1), 20% label noise) — confirm before
            reusing it elsewhere.
        """
        Ein = minErrors / num_data
        Eout = 0.5 + 0.3 * min_s * (abs(min_theta) - 1)
        return Ein, Eout
    
    
    if __name__=="__main__":
        # Number of independent experiments (one gen_data() dataset each).
        num_trials = 5000
        Ein_list = []
        Eout_list = []
        for _ in range(num_trials):
            dataArray = gen_data()
            minErrors, min_s, min_theta = decision_stump(dataArray)
            Ein, Eout = computeEinEout(minErrors, min_s, min_theta)
            Ein_list.append(Ein)
            Eout_list.append(Eout)

        # show results
        # 17 & 18
        print("the average Ein: ", sum(Ein_list) / num_trials)
        print("the average Eout: ", sum(Eout_list) / num_trials)

        plt.figure(figsize=(16, 6))
        plt.subplot(121)
        plt.hist(Ein_list)
        plt.xlabel("Ein")
        plt.ylabel("frequency")
        plt.subplot(122)
        plt.hist(Eout_list)
        plt.xlabel("Eout")
        plt.ylabel("frequency")
        plt.savefig("EinEout.png")
        plt.close()
    

    19-20

    # coding: utf-8
    
    import numpy as np
    
    def read_data(dataFile):
        """Load a whitespace-separated numeric text file into a 2-D float array.

        Each non-blank line becomes one row. Blank lines (e.g. an extra empty
        line at the end of the file) are skipped; otherwise they would create
        empty rows and a ragged array.

        Args:
            dataFile: path of the file to read.

        Returns:
            ndarray with one row per non-blank line.
        """
        with open(dataFile, 'r') as file:
            data_list = [[float(value) for value in line.split()]
                         for line in file if line.strip()]
        return np.array(data_list)
    
    def predict(s, theta, dataX):
        """Apply the decision stump h(x) = s * sign(x - theta) element-wise.

        A point lying exactly on the threshold gets -s (sign(0) = -1).
        np.sign would return 0 there, and 0 never matches a +/-1 label, so
        such points would always count as errors.

        Args:
            s: stump direction (+1 or -1).
            theta: stump threshold.
            dataX: array of x values; the output has the same shape.

        Returns:
            Array of predictions in {-s, +s}, same shape as dataX.
        """
        return s * np.where(dataX - theta > 0, 1.0, -1.0)
    
    def decision_stump(dataArray):
        """Find the 1-D decision stump h(x) = s * sign(x - theta) with the fewest training errors.

        Candidate thresholds are the midpoints between consecutive sorted x
        values, plus x_max + 0.5 above the largest point. Both directions
        s in {-1, +1} are tried, and sign(0) is treated as -1. Ties are broken
        uniformly at random with np.random.randint.

        Args:
            dataArray: ndarray of shape (N, 2); column 0 is x, column 1 is y.

        Returns:
            (minErrors, min_s, min_theta): number of misclassified points of the
            chosen stump, its direction and its threshold.

        Raises:
            ValueError: if dataArray has no rows.
        """
        num_data = dataArray.shape[0]
        if num_data == 0:
            raise ValueError("dataArray must contain at least one sample")
        order = np.argsort(dataArray[:, 0], kind="stable")
        dataX = dataArray[order, 0].reshape(num_data, 1)
        dataY = dataArray[order, 1].reshape(num_data, 1)
        # Midpoints between neighbouring x values, plus one threshold above x_max.
        thetas = np.append((dataX[:-1, 0] + dataX[1:, 0]) / 2,
                           (dataX[-1, 0] * 2 + 1) / 2)
        minErrors = num_data
        min_s_theta_list = []
        for s in [-1.0, 1.0]:
            for theta in thetas:
                # sign(0) = -1: with duplicate x values theta can equal a data
                # point, and np.sign(0) = 0 would always count it as an error.
                pred = s * np.where(dataX - theta > 0, 1.0, -1.0)
                errors = int(np.sum(pred != dataY))
                if errors < minErrors:
                    minErrors = errors
                    min_s_theta_list = [(s, float(theta))]
                elif errors == minErrors:
                    min_s_theta_list.append((s, float(theta)))
        pick = np.random.randint(low=0, high=len(min_s_theta_list))
        min_s, min_theta = min_s_theta_list[pick]
        return minErrors, min_s, min_theta
    
    def best_of_best(candidate):
        """Pick the per-dimension stump with the fewest training errors.

        Args:
            candidate: list of [dim, errors, s, theta] entries.

        Returns:
            (dim, errors, s, theta) of the best entry. Ties on the error count
            are broken uniformly at random with np.random.randint, among the
            tied entries in their original order. The input list is not modified.

        Raises:
            ValueError: if candidate is empty.
        """
        if not candidate:
            raise ValueError("candidate must not be empty")
        best_errors = min(c[1] for c in candidate)
        # Filtering (instead of sorting in place) leaves the caller's list untouched.
        ties = [c for c in candidate if c[1] == best_errors]
        chosen = ties[np.random.randint(low=0, high=len(ties))]
        return chosen[0], chosen[1], chosen[2], chosen[3]
    
    
    if __name__=="__main__":
        data_array = read_data("hw2_train.dat")
        num_data = data_array.shape[0]
        num_dim = data_array.shape[1] - 1  # last column holds the label
        dataY = data_array[:, -1].reshape(num_data, 1)
        # Train one stump per feature dimension, then keep the best of them.
        candidate = []
        for i in range(num_dim):
            dataX = data_array[:, i].reshape(num_data, 1)
            min_errors, min_s, min_theta = decision_stump(np.concatenate((dataX, dataY), axis=1))
            candidate.append([i, min_errors, min_s, min_theta])
        min_id, min_errors, min_s, min_theta = best_of_best(candidate)
        print("the optimal decision stump:\n", "s: ", min_s, "\ntheta: ", min_theta)
        print("the Ein of the optimal decision stump:\n", min_errors / num_data)

        # Evaluate the chosen stump on the test file (Etest used as the Eout estimate).
        test_array = read_data("hw2_test.dat")
        num_test = test_array.shape[0]
        testY = test_array[:, -1].reshape(num_test, 1)
        testX = test_array[:, min_id].reshape(num_test, 1)
        pred = predict(min_s, min_theta, testX)
        print("the Eout of the optimal decision stump by Etest:\n", np.sum(pred != testY) / num_test)
    

    运行结果

    17-18

    图3 17-18结果1
    图4 17-18结果2

    19-20

    图5 19-20结果

  • 相关阅读:
    windows窗口消息内部处理机制
    iPhone and iPad Development GUI Kits, Stencils and Icons
    【转】windbg 调试经典文章(常用)
    atl和mfc
    开发IDA pro图形界面插件
    ida常用插件
    为Visual studio 2008 添加汇编工程模板
    常用软件汇总
    BOOL EnumInternetExplorer( ProcessWebBrowser pHander )
    同年龄的牛人博客
  • 原文地址:https://www.cnblogs.com/cherrychenlee/p/10796683.html
Copyright © 2011-2022 走看看