zoukankan      html  css  js  c++  java
  • 机器学习基石笔记:Homework #4 Regularization&Validation相关习题

    原文地址:https://www.jianshu.com/p/3f7d4aa6a7cf

    问题描述

    图1 13-15
    图2 16-18
    图3 19-20

    程序实现

    # coding: utf-8
    
    import numpy as np
    import math
    import matplotlib.pyplot as plt
    
    def sign(x):
        """Return 1 when x is non-negative (zero counts as positive), else -1."""
        return 1 if x>=0 else -1
    
    def read_data(dataFile):
        """Load a whitespace-separated numeric file into a design matrix and labels.

        Each non-blank line is parsed as floats: the last value is the label and
        the values before it are the features. A constant 1.0 is prepended to
        every row as the bias term. Blank lines (for example a trailing empty
        line) are skipped; otherwise they would contribute a one-element row and
        make the rows ragged.

        Args:
            dataFile: path of the text file to read.

        Returns:
            (dataX, dataY): dataX has shape (num_data, num_features + 1) with a
            leading column of ones; dataY has shape (num_data, 1).
        """
        data_list=[]
        with open(dataFile,'r') as f:
            # Stream the file line by line instead of materialising readlines().
            for line in f:
                fields=line.split()
                if not fields:
                    continue
                data_list.append([1.0]+[float(v) for v in fields])
        dataArray=np.array(data_list)
        # Slicing with -1: keeps the label as a column vector of shape (num_data, 1).
        dataX=dataArray[:,:-1]
        dataY=dataArray[:,-1:]
        return dataX,dataY
    
    def w_reg(dataX,dataY,namuta):
        """Return the closed-form ridge-regression weights.

        Computes w = (X^T X + namuta * I)^(-1) X^T y by solving the linear system
        (X^T X + namuta * I) w = X^T y directly instead of forming the explicit
        inverse, which is more accurate when the regularised Gram matrix is
        ill-conditioned (very small namuta).

        Args:
            dataX: design matrix with one row per sample.
            dataY: targets with one row per sample.
            namuta: regularisation strength (lambda).

        Returns:
            Weight array with dataX.shape[1] rows and the same trailing shape
            as dataY.

        Raises:
            numpy.linalg.LinAlgError: if the regularised Gram matrix is singular.
        """
        num_dim=dataX.shape[1]
        dataX_T=np.transpose(dataX)
        gram=np.dot(dataX_T,dataX)+namuta*np.eye(num_dim)
        return np.linalg.solve(gram,np.dot(dataX_T,dataY))
    
    def pred(wREG,dataX):
        """Classify every row of dataX as +1 or -1 using the linear weights wREG.

        The score is dataX . wREG; a score >= 0 (zero included) maps to +1 and
        anything else to -1. The mapping is applied to the whole score array in
        one vectorised step, so it does not depend on the score having exactly
        one column, and the result keeps the dtype of the score array.

        Args:
            wREG: weight array.
            dataX: design matrix with one row per sample.

        Returns:
            Array shaped like np.dot(dataX, wREG), holding only +1 and -1.
        """
        scores=np.dot(dataX,wREG)
        return np.where(scores>=0,1,-1).astype(scores.dtype)
    
    def zero_one_cost(pred,dataY):
        """Return the 0/1 error: mismatching entries divided by the row count of dataY."""
        mismatches=(pred!=dataY)
        num_rows=dataY.shape[0]
        return np.sum(mismatches)/num_rows
    
    
    def _save_curve(lambdas,errors,ylabel,filename):
        """Plot errors against lambdas on a fresh figure and save it to filename."""
        plt.figure()
        plt.plot(lambdas,errors)
        plt.xlabel("namuta")
        plt.xlim((lambdas[0],lambdas[-1]))
        plt.ylabel(ylabel)
        plt.savefig(filename)


    if __name__=="__main__":
        # Problem 13: one ridge fit with lambda = 10 on the full training set.
        dataX,dataY=read_data("hw4_train.dat")
        print("\n13")
        wREG=w_reg(dataX,dataY,namuta=10)
        Ein=zero_one_cost(pred(wREG,dataX),dataY)
        print("the Ein on the train set: ",Ein)
        # test
        testX,testY=read_data("hw4_test.dat")
        Eout=zero_one_cost(pred(wREG,testX),testY)
        print("the Eout on the test set: ",Eout)

        # Candidate log10(lambda) values, largest first: list.index(min(...))
        # returns the first minimum, so ties are resolved towards the largest lambda.
        l=[2,1,0,-1,-2,-3,-4,-5,-6,-7,-8,-9,-10]
        # Build the lambdas as floats; raising an integer array to negative
        # integer powers with np.power raises ValueError.
        lambdas=[math.pow(10,i) for i in l]

        # Problem 14: pick lambda by the smallest Ein on the full training set.
        print("\n14")
        Ein_list=[]
        Eout_list=[]
        for namuta in lambdas:
            wREG=w_reg(dataX,dataY,namuta)
            Ein_list.append(zero_one_cost(pred(wREG,dataX),dataY))
            Eout_list.append(zero_one_cost(pred(wREG,testX),testY))
        id_in=Ein_list.index(min(Ein_list))
        _save_curve(lambdas,Ein_list,"Ein","14.png")
        print("the namuta with the minimun Ein: ",lambdas[id_in])
        print("the Eout on such namuta: ", Eout_list[id_in])

        # Problem 15: pick lambda by the smallest Eout.
        print("\n15")
        id_out = Eout_list.index(min(Eout_list))
        _save_curve(lambdas,Eout_list,"Eout","15.png")
        print("the namuta with the minimun Eout: ", lambdas[id_out])

        # Hold-out split: first 120 rows for training, the remainder for validation.
        trainX=dataX[:120]
        trainY=dataY[:120]
        validX=dataX[120:]
        validY=dataY[120:]

        # Problem 16: train on the 120-row subset, pick lambda by the smallest Ein.
        print("\n16")
        Ein_list.clear()
        Eout_list.clear()
        Eval_list=[]
        for namuta in lambdas:
            wREG=w_reg(trainX,trainY,namuta)
            Ein_list.append(zero_one_cost(pred(wREG,trainX),trainY))
            Eout_list.append(zero_one_cost(pred(wREG,testX),testY))
            Eval_list.append(zero_one_cost(pred(wREG,validX),validY))
        id_in=Ein_list.index(min(Ein_list))
        _save_curve(lambdas,Ein_list,"Ein","16.png")
        print("the namuta with the minimun Ein: ",lambdas[id_in])
        print("the Eout on such namuta: ", Eout_list[id_in])

        # Problem 17: same fits, pick lambda by the smallest validation error.
        print("\n17")
        id_val=Eval_list.index(min(Eval_list))
        _save_curve(lambdas,Eval_list,"Eval","17.png")
        print("the namuta with the minimun Eval: ",lambdas[id_val])
        print("the Eout on such namuta: ", Eout_list[id_val])

        # Problem 18: refit on the full training set with the validation-selected lambda.
        print("\n18")
        wREG=w_reg(dataX,dataY,namuta=lambdas[id_val])
        Ein=zero_one_cost(pred(wREG,dataX),dataY)
        Eout = zero_one_cost(pred(wREG, testX), testY)
        print("Ein: ",Ein)
        print("Eout: ",Eout)

        # Problem 19: 5-fold cross validation over equal, contiguous folds.
        print("\n19")
        Eval_list.clear()
        splX=np.split(dataX,5,axis=0)
        splY=np.split(dataY,5,axis=0)
        for namuta in lambdas:
            Eval = 0
            for i in range(5):
                # Train on every fold except fold i, validate on fold i.
                li=[k for k in range(5) if k!=i]
                trainX=np.concatenate([splX[k] for k in li],axis=0)
                trainY=np.concatenate([splY[k] for k in li],axis=0)
                wREG=w_reg(trainX,trainY,namuta)
                Eval+=zero_one_cost(pred(wREG,splX[i]),splY[i])/5
            Eval_list.append(Eval)
        id_val=Eval_list.index(min(Eval_list))
        _save_curve(lambdas,Eval_list,"Ecv","19.png")
        print("the namuta with the minimun Ecv: ",lambdas[id_val])

        # Problem 20: refit on the full training set with the CV-selected lambda.
        print("\n20")
        wREG=w_reg(dataX,dataY,namuta=lambdas[id_val])
        Ein=zero_one_cost(pred(wREG,dataX),dataY)
        Eout = zero_one_cost(pred(wREG, testX), testY)
        print("Ein: ",Ein)
        print("Eout: ",Eout)
    

    运行结果

    13

    图4 13结果

    14

    图5 14结果1
    图6 14结果2

    15

    图7 15结果1
    图8 15结果2

    16

    图9 16结果1
    图10 16结果2

    17

    图11 17结果1
    图12 17结果2

    18

    图13 18结果

    19

    图14 19结果1
    图15 19结果2

    20

    图16 20结果

  • 相关阅读:
    react项目中如何解决同时需要多个请求问题
    jq+ajax+bootstrap改写一个动态分页的表格
    Window7+vs2008+QT环境搭建
    mssql charindex
    解决NTLDR is missing,系统无法启动的方法
    基于三汇语音卡的呼叫中心开发(一)
    Wince 或Windows平台 C#调用Bitmap对象后资源应该如何释放
    Anki:插件开发
    java.lang.ClassNotFoundException: com.opensymphony.xwork2.util.ValueStack
    struts2中action之间的一种跳转
  • 原文地址:https://www.cnblogs.com/cherrychenlee/p/10800346.html
Copyright © 2011-2022 走看看