zoukankan      html  css  js  c++  java
  • 机器学习基石笔记:Homework #4 Regularization&Validation相关习题

    原文地址:https://www.jianshu.com/p/3f7d4aa6a7cf

    问题描述

    图1 第13-15题题目
    图2 第16-18题题目
    图3 第19-20题题目

    程序实现

    # coding: utf-8
    
    import numpy as np
    import math
    import matplotlib.pyplot as plt
    
    def sign(x):
        """Return +1 when x is non-negative and -1 otherwise (so sign(0) == 1)."""
        return 1 if x >= 0 else -1
    
    def read_data(dataFile):
        """Load a whitespace-separated data file into a design matrix and labels.

        Each non-blank line holds the feature values followed by the label.
        A constant bias feature 1.0 is prepended to every row.

        Args:
            dataFile: path to the text file.

        Returns:
            (dataX, dataY): dataX has shape (N, d + 1) including the bias
            column; dataY has shape (N, 1).

        Raises:
            ValueError: if the file contains no data rows.
        """
        data_list = []
        with open(dataFile, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Skip blank lines (e.g. a trailing newline at EOF); they
                    # would otherwise add a ragged [1.0] row and break np.array.
                    continue
                data_list.append([1.0] + [float(v) for v in fields])
        dataArray = np.array(data_list, dtype=float)
        if dataArray.size == 0:
            raise ValueError(f"no data rows found in {dataFile!r}")
        dataX = dataArray[:, :-1]
        dataY = dataArray[:, -1].reshape(-1, 1)
        return dataX, dataY
    
    def w_reg(dataX, dataY, namuta):
        """Compute closed-form ridge-regression (L2-regularized) weights.

        Solves (X^T X + namuta * I) w = X^T y for w.

        Args:
            dataX: design matrix of shape (N, d).
            dataY: targets, shape (N, 1).
            namuta: regularization strength (lambda).

        Returns:
            Weight array of shape (d, 1) when dataY has shape (N, 1).

        Raises:
            numpy.linalg.LinAlgError: if the regularized matrix is singular.
        """
        num_dim = dataX.shape[1]
        dataX_T = dataX.T
        regularized = np.dot(dataX_T, dataX) + namuta * np.eye(num_dim)
        # Solve the linear system directly instead of forming an explicit
        # inverse: same solution, better numerical stability, less work.
        return np.linalg.solve(regularized, np.dot(dataX_T, dataY))
    
    def pred(wREG, dataX):
        """Predict +1/-1 labels with a linear model.

        Args:
            wREG: weight array, shape (d, 1).
            dataX: design matrix, shape (N, d).

        Returns:
            Float array of shape (N, 1) containing +1.0 or -1.0. A score of
            exactly 0 maps to +1, matching ``sign``.
        """
        scores = np.dot(dataX, wREG)
        # Vectorized equivalent of applying sign() element-wise: no Python
        # loop and no in-place mutation of a local that shadows this function.
        return np.where(scores >= 0, 1.0, -1.0)
    
    def zero_one_cost(pred, dataY):
        """Return the 0/1 error: mismatching entries divided by the number of rows of dataY."""
        mismatches = (pred != dataY)
        wrong_count = np.sum(mismatches)
        row_count = dataY.shape[0]
        return wrong_count / row_count
    
    
    def _plot_curve(exponents, values, ylabel, filename):
        """Plot an error curve against lambda = 10**exponent and save it to filename."""
        plt.figure()
        # Float base: NumPy refuses integer bases raised to negative integer
        # powers, which the previous np.int32-based expression triggered.
        lambdas = np.power(10.0, exponents)
        plt.plot(lambdas, values)
        # The lambdas span many orders of magnitude; a log axis keeps every
        # point distinguishable.
        plt.xscale("log")
        plt.xlabel("namuta")
        # Keep the original orientation: first exponent on the left.
        plt.xlim((lambdas[0], lambdas[-1]))
        plt.ylabel(ylabel)
        plt.savefig(filename)
        plt.close()


    def _cv_error(dataX, dataY, namuta, num_folds=5):
        """Return the average 0/1 validation error of ridge regression over consecutive folds."""
        foldsX = np.array_split(dataX, num_folds, axis=0)
        foldsY = np.array_split(dataY, num_folds, axis=0)
        Ecv = 0
        for i in range(num_folds):
            # Train on every fold except fold i, validate on fold i.
            trainX = np.concatenate(foldsX[:i] + foldsX[i + 1:], axis=0)
            trainY = np.concatenate(foldsY[:i] + foldsY[i + 1:], axis=0)
            wREG = w_reg(trainX, trainY, namuta)
            Ecv += zero_one_cost(pred(wREG, foldsX[i]), foldsY[i]) / num_folds
        return Ecv


    if __name__ == "__main__":
        # Q13: ridge regression with lambda = 10 on the full training set.
        dataX, dataY = read_data("hw4_train.dat")
        print("\n13")
        wREG = w_reg(dataX, dataY, namuta=10)
        Ein = zero_one_cost(pred(wREG, dataX), dataY)
        print("the Ein on the train set: ", Ein)
        testX, testY = read_data("hw4_test.dat")
        Eout = zero_one_cost(pred(wREG, testX), testY)
        print("the Eout on the test set: ", Eout)

        # lambda = 10**e for e = 2, 1, ..., -10. The list runs from the largest
        # lambda to the smallest, so list.index(min(...)) breaks ties in favour
        # of the larger lambda.
        l = [2, 1, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10]

        # Q14/Q15: train on the full training set for each lambda.
        print("\n14")
        Ein_list = []
        Eout_list = []
        for i in l:
            wREG = w_reg(dataX, dataY, math.pow(10, i))
            Ein_list.append(zero_one_cost(pred(wREG, dataX), dataY))
            Eout_list.append(zero_one_cost(pred(wREG, testX), testY))
        id_in = Ein_list.index(min(Ein_list))
        _plot_curve(l, Ein_list, "Ein", "14.png")
        print("the namuta with the minimun Ein: ", math.pow(10, l[id_in]))
        print("the Eout on such namuta: ", Eout_list[id_in])

        print("\n15")
        id_out = Eout_list.index(min(Eout_list))
        _plot_curve(l, Eout_list, "Eout", "15.png")
        print("the namuta with the minimun Eout: ", math.pow(10, l[id_out]))

        # Q16/Q17: first 120 examples for training, the rest for validation.
        trainX, trainY = dataX[:120], dataY[:120]
        validX, validY = dataX[120:], dataY[120:]

        print("\n16")
        Ein_list = []
        Eout_list = []
        Eval_list = []
        for i in l:
            wREG = w_reg(trainX, trainY, math.pow(10, i))
            Ein_list.append(zero_one_cost(pred(wREG, trainX), trainY))
            Eout_list.append(zero_one_cost(pred(wREG, testX), testY))
            Eval_list.append(zero_one_cost(pred(wREG, validX), validY))
        id_in = Ein_list.index(min(Ein_list))
        _plot_curve(l, Ein_list, "Ein", "16.png")
        print("the namuta with the minimun Ein: ", math.pow(10, l[id_in]))
        print("the Eout on such namuta: ", Eout_list[id_in])

        print("\n17")
        id_val = Eval_list.index(min(Eval_list))
        _plot_curve(l, Eval_list, "Eval", "17.png")
        print("the namuta with the minimun Eval: ", math.pow(10, l[id_val]))
        print("the Eout on such namuta: ", Eout_list[id_val])

        # Q18: retrain on all training data with the lambda chosen by validation.
        print("\n18")
        wREG = w_reg(dataX, dataY, namuta=math.pow(10, l[id_val]))
        Ein = zero_one_cost(pred(wREG, dataX), dataY)
        Eout = zero_one_cost(pred(wREG, testX), testY)
        print("Ein: ", Ein)
        print("Eout: ", Eout)

        # Q19: 5-fold cross validation over the same lambda grid.
        print("\n19")
        Ecv_list = [_cv_error(dataX, dataY, math.pow(10, j)) for j in l]
        id_cv = Ecv_list.index(min(Ecv_list))
        _plot_curve(l, Ecv_list, "Ecv", "19.png")
        print("the namuta with the minimun Ecv: ", math.pow(10, l[id_cv]))

        # Q20: retrain on all training data with the lambda chosen by CV.
        print("\n20")
        wREG = w_reg(dataX, dataY, namuta=math.pow(10, l[id_cv]))
        Ein = zero_one_cost(pred(wREG, dataX), dataY)
        Eout = zero_one_cost(pred(wREG, testX), testY)
        print("Ein: ", Ein)
        print("Eout: ", Eout)
    

    运行结果

    13

    图4 13结果

    14

    图5 14结果1
    图6 14结果2

    15

    图7 15结果1
    图8 15结果2

    16

    图9 16结果1
    图10 16结果2

    17

    图11 17结果1
    图12 17结果2

    18

    图13 18结果

    19

    图14 19结果1
    图15 19结果2

    20

    图16 20结果

  • 相关阅读:
    分享免费的jQuery Mobile Wordpress主题 jQMobile
    分享50个使用非比寻常导航菜单设计的创意网站
    分享一个超酷javascript全屏幻灯导航(fullscreen slide navigation)
    分享一款jQuery的UI插件:Ninja UI
    使用jQuery开发一个超酷的倒计时效果
    分享使用jQuery和CSS实现的一个超酷缩略图悬浮逼近效果
    Nosql数据库教程之初探MongoDB 第一部分
    分享一个使一行文字变形产生弯曲弧度特效的jQuery插件 Arctext.js
    分享2011年12月的11个最棒的jQuery插件
    分享8个最新的javascript脚本资源
  • 原文地址:https://www.cnblogs.com/cherrychenlee/p/10800346.html
Copyright © 2011-2022 走看看