zoukankan      html  css  js  c++  java
  • 机器学习基石笔记:Homework #4 Regularization&Validation相关习题

    原文地址:https://www.jianshu.com/p/3f7d4aa6a7cf

    问题描述

    图1 13-15
    图2 16-18
    图3 19-20

    程序实现

    # coding: utf-8
    
    import numpy as np
    import math
    import matplotlib.pyplot as plt
    
    def sign(x):
        """Return +1 when x is non-negative (zero included), otherwise -1."""
        return 1 if x >= 0 else -1
    
    def read_data(dataFile):
        """Load a whitespace-separated data file into features and labels.

        Every non-blank line holds the feature values followed by the label in
        the last column. A constant 1.0 is prepended to each row as the bias
        feature.

        Args:
            dataFile: Path of the text file to read.

        Returns:
            A tuple (dataX, dataY). dataX has shape (num_data, num_dim) and
            includes the leading bias column; dataY has shape (num_data, 1).
        """
        data_list = []
        with open(dataFile, 'r') as f:
            for line in f:
                fields = line.split()
                # Skip blank lines (e.g. a trailing newline at end of file);
                # they would otherwise add a lone [1.0] row and make the
                # array ragged.
                if not fields:
                    continue
                data_list.append([1.0] + [float(v) for v in fields])
        dataArray = np.array(data_list)
        dataX = dataArray[:, :-1]
        # Slice with -1: (not -1) so the labels stay a column vector.
        dataY = dataArray[:, -1:]
        return dataX, dataY
    
    def w_reg(dataX, dataY, namuta):
        """Return the ridge-regression weights for the given data.

        Computes w = (X^T X + namuta * I)^(-1) X^T y.

        Args:
            dataX: Feature matrix of shape (num_data, num_dim).
            dataY: Targets with num_data rows.
            namuta: Regularization strength (lambda).

        Returns:
            The weights, with num_dim rows and the same trailing shape as
            dataY.
        """
        num_dim = dataX.shape[1]
        dataX_T = np.transpose(dataX)
        regularized_gram = np.dot(dataX_T, dataX) + namuta * np.eye(num_dim)
        # Solve the linear system directly instead of forming the inverse:
        # it is cheaper and loses less precision when namuta is tiny and the
        # matrix is badly conditioned.
        return np.linalg.solve(regularized_gram, np.dot(dataX_T, dataY))
    
    def pred(wREG, dataX):
        """Classify each row of dataX with the linear model wREG.

        Args:
            wREG: Weights, one row per feature.
            dataX: Feature matrix, one row per sample.

        Returns:
            A float array shaped like np.dot(dataX, wREG) holding +1.0 where
            the score is >= 0 and -1.0 elsewhere.
        """
        scores = np.dot(dataX, wREG)
        # Vectorized sign with the convention that a score of zero maps to +1.
        return np.where(scores >= 0, 1.0, -1.0)
    
    def zero_one_cost(pred, dataY):
        """Return the 0/1 error: mismatching entries divided by rows of dataY."""
        mismatches = pred != dataY
        num_data = dataY.shape[0]
        return mismatches.sum() / num_data
    
    
    def _plot_error_curve(lambdas, errors, ylabel, filename):
        """Plot errors against lambdas in a new figure and save it to filename."""
        plt.figure()
        plt.plot(lambdas, errors)
        plt.xlabel("namuta")
        plt.xlim((lambdas[0], lambdas[-1]))
        plt.ylabel(ylabel)
        plt.savefig(filename)


    if __name__ == "__main__":
        # Problem 13: fit on the full training set with lambda = 10.
        dataX, dataY = read_data("hw4_train.dat")
        print("\n13")
        wREG = w_reg(dataX, dataY, namuta=10)
        Ein = zero_one_cost(pred(wREG, dataX), dataY)
        print("the Ein on the train set: ", Ein)
        # test
        testX, testY = read_data("hw4_test.dat")
        Eout = zero_one_cost(pred(wREG, testX), testY)
        print("the Eout on the test set: ", Eout)

        # Candidate lambdas 10^2 ... 10^-10, listed from largest to smallest.
        # list.index(min(...)) returns the first minimum, so ties are broken
        # in favour of the largest lambda.
        exponents = [2, 1, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10]
        # Computed as floats: NumPy refuses to raise an integer array to a
        # negative integer power.
        lambdas = [math.pow(10, e) for e in exponents]

        # Problem 14: pick lambda by the smallest Ein on the full training set.
        print("\n14")
        Ein_list = []
        Eout_list = []
        for namuta in lambdas:
            wREG = w_reg(dataX, dataY, namuta)
            Ein_list.append(zero_one_cost(pred(wREG, dataX), dataY))
            Eout_list.append(zero_one_cost(pred(wREG, testX), testY))
        id_in = Ein_list.index(min(Ein_list))
        _plot_error_curve(lambdas, Ein_list, "Ein", "14.png")
        print("the namuta with the minimun Ein: ", lambdas[id_in])
        print("the Eout on such namuta: ", Eout_list[id_in])

        # Problem 15: pick lambda by the smallest Eout.
        print("\n15")
        id_out = Eout_list.index(min(Eout_list))
        _plot_error_curve(lambdas, Eout_list, "Eout", "15.png")
        print("the namuta with the minimun Eout: ", lambdas[id_out])

        # Hold out everything after the first 120 rows for validation.
        trainX = dataX[:120]
        trainY = dataY[:120]
        validX = dataX[120:]
        validY = dataY[120:]

        # Problem 16: train on the first 120 rows, pick lambda by Ein there.
        print("\n16")
        Ein_list = []
        Eout_list = []
        Eval_list = []
        for namuta in lambdas:
            wREG = w_reg(trainX, trainY, namuta)
            Ein_list.append(zero_one_cost(pred(wREG, trainX), trainY))
            Eout_list.append(zero_one_cost(pred(wREG, testX), testY))
            Eval_list.append(zero_one_cost(pred(wREG, validX), validY))
        id_in = Ein_list.index(min(Ein_list))
        _plot_error_curve(lambdas, Ein_list, "Ein", "16.png")
        print("the namuta with the minimun Ein: ", lambdas[id_in])
        print("the Eout on such namuta: ", Eout_list[id_in])

        # Problem 17: pick lambda by the smallest validation error.
        print("\n17")
        id_val = Eval_list.index(min(Eval_list))
        _plot_error_curve(lambdas, Eval_list, "Eval", "17.png")
        print("the namuta with the minimun Eval: ", lambdas[id_val])
        print("the Eout on such namuta: ", Eout_list[id_val])

        # Problem 18: retrain on all training rows with the lambda from 17.
        print("\n18")
        wREG = w_reg(dataX, dataY, namuta=lambdas[id_val])
        Ein = zero_one_cost(pred(wREG, dataX), dataY)
        Eout = zero_one_cost(pred(wREG, testX), testY)
        print("Ein: ", Ein)
        print("Eout: ", Eout)

        # Problem 19: 5-fold cross validation over consecutive equal folds.
        print("\n19")
        Eval_list = []
        splX = np.split(dataX, 5, axis=0)
        splY = np.split(dataY, 5, axis=0)
        for namuta in lambdas:
            Eval = 0
            for i in range(5):
                # Train on the four folds other than fold i, validate on fold i.
                rest = [k for k in range(5) if k != i]
                trainX = np.concatenate([splX[k] for k in rest], axis=0)
                trainY = np.concatenate([splY[k] for k in rest], axis=0)
                wREG = w_reg(trainX, trainY, namuta)
                Eval += zero_one_cost(pred(wREG, splX[i]), splY[i]) / 5
            Eval_list.append(Eval)
        id_val = Eval_list.index(min(Eval_list))
        _plot_error_curve(lambdas, Eval_list, "Ecv", "19.png")
        print("the namuta with the minimun Ecv: ", lambdas[id_val])

        # Problem 20: retrain on all training rows with the lambda from 19.
        print("\n20")
        wREG = w_reg(dataX, dataY, namuta=lambdas[id_val])
        Ein = zero_one_cost(pred(wREG, dataX), dataY)
        Eout = zero_one_cost(pred(wREG, testX), testY)
        print("Ein: ", Ein)
        print("Eout: ", Eout)
    

    运行结果

    13

    图4 13结果

    14

    图5 14结果1
    图6 14结果2

    15

    图7 15结果1
    图8 15结果2

    16

    图9 16结果1
    图10 16结果2

    17

    图11 17结果1
    图12 17结果2

    18

    图13 18结果

    19

    图14 19结果1
    图15 19结果2

    20

    图16 20结果

  • 相关阅读:
    为什么使用C#开发软件的公司和程序员都很少?
    使用Redis之前5个必须了解的事情
    这段代码为什么捕获不到异常呢?谁能给个解释,谢谢。
    git报错
    C# 常用类库(字符串处理,汉字首字母拼音,注入攻击,缓存操作,Cookies操作,AES加密等)
    你所不知道的 CSS 负值技巧与细节
    CSS 属性选择器的深入挖掘
    探秘 flex 上下文中神奇的自动 margin
    CSS 火焰?不在话下
    不可思议的纯 CSS 实现鼠标跟随效果
  • 原文地址:https://www.cnblogs.com/cherrychenlee/p/10800346.html
Copyright © 2011-2022 走看看