zoukankan      html  css  js  c++  java
  • 机器学习基石笔记：Homework #4 Regularization & Validation 相关习题

    原文地址:https://www.jianshu.com/p/3f7d4aa6a7cf

    问题描述

    图1 第13–15题
    图2 第16–18题
    图3 第19–20题

    程序实现

    # coding: utf-8
    
    import numpy as np
    import math
    import matplotlib.pyplot as plt
    
    def sign(x):
        """Return +1 when ``x`` is non-negative and -1 otherwise.

        Unlike ``np.sign``, zero maps to +1. NaN maps to -1 because ``NaN >= 0``
        evaluates to False.
        """
        return 1 if x >= 0 else -1
    
    def read_data(dataFile):
        """Load a whitespace-separated data file into a design matrix and labels.

        Each non-blank line holds the feature values followed by the label. A
        constant 1.0 (bias term) is prepended to every row.

        Args:
            dataFile: path of the text file to read.

        Returns:
            (dataX, dataY): ``dataX`` has shape (N, d+1) with the bias in column 0,
            where d is the number of feature columns; ``dataY`` has shape (N, 1).

        Raises:
            ValueError: if the file holds no data rows, or rows have differing
                numbers of columns.
        """
        rows = []
        with open(dataFile, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                # Skip blank lines (e.g. a trailing newline at EOF); otherwise they
                # would become a ragged [1.0] row and break np.array below.
                if not fields:
                    continue
                rows.append([1.0] + [float(v) for v in fields])
        if not rows:
            raise ValueError("no data rows found in %r" % (dataFile,))
        dataArray = np.array(rows, dtype=float)
        dataX = dataArray[:, :-1]
        # Slicing with -1: (not -1) keeps the label column two-dimensional, (N, 1).
        dataY = dataArray[:, -1:]
        return dataX, dataY
    
    def w_reg(dataX, dataY, namuta):
        """Return the ridge-regression (weight-decay) weights.

        Solves (X^T X + namuta * I) w = X^T y, the closed-form minimizer of
        ||X w - y||^2 + namuta * ||w||^2. Every weight is regularized, including
        the bias weight in column 0.

        Args:
            dataX: (N, d) design matrix.
            dataY: targets, (N, 1) or (N,).
            namuta: regularization strength lambda (expected >= 0).

        Returns:
            Weights of shape (d, 1) for (N, 1) targets, or (d,) for (N,) targets.
        """
        num_dim = dataX.shape[1]
        dataX_T = np.transpose(dataX)
        gram = np.dot(dataX_T, dataX) + namuta * np.eye(num_dim)
        # Solving the linear system is cheaper and numerically more stable than
        # forming the explicit inverse and multiplying by it.
        return np.linalg.solve(gram, np.dot(dataX_T, dataY))
    
    def pred(wREG, dataX):
        """Classify each row of ``dataX`` with the linear weights ``wREG``.

        Args:
            wREG: weight vector, e.g. shape (d, 1) as returned by ``w_reg``.
            dataX: (N, d) design matrix.

        Returns:
            Float array shaped like ``dataX @ wREG``. Each entry is +1.0 where the
            score is >= 0 and -1.0 otherwise; NaN scores map to -1.0, matching
            ``sign``.
        """
        scores = np.dot(dataX, wREG)
        # Vectorized element-wise sign. This avoids the Python loop and the local
        # that shadowed this function's name, and it also works for 1-D weights.
        return np.where(scores >= 0, 1.0, -1.0)
    
    def zero_one_cost(pred, dataY):
        """Return the 0/1 error: the fraction of examples where ``pred`` != ``dataY``.

        The number of mismatching entries is divided by ``dataY.shape[0]``, the
        number of examples.
        """
        num_examples = dataY.shape[0]
        mismatches = pred != dataY
        return np.sum(mismatches) / num_examples
    
    
    def _plot_error_curve(log_lambdas, errors, ylabel, filename):
        """Plot ``errors`` against lambda = 10**e for e in ``log_lambdas``; save it.

        The x-axis is logarithmic. It spans from the first to the last lambda in
        the given order, so a descending exponent list gives a reversed axis.
        """
        # Float base: an integer base with negative exponents makes np.power raise
        # "Integers to negative integer powers are not allowed".
        lambdas = np.power(10.0, log_lambdas)
        plt.figure()
        plt.plot(lambdas, errors, marker="o")
        # Log scale so 10**-10 ... 10**0 do not all collapse onto x = 0.
        plt.xscale("log")
        plt.xlim((lambdas[0], lambdas[-1]))
        plt.xlabel("namuta")
        plt.ylabel(ylabel)
        plt.savefig(filename)
        plt.close()  # several figures are created per run; release each one


    def _argmin_first(errors):
        """Return the index of the smallest value in ``errors`` (first one on ties)."""
        return errors.index(min(errors))


    def _cv_error(dataX, dataY, namuta, num_folds=5):
        """Return the ``num_folds``-fold cross-validation 0/1 error of ``w_reg``.

        Folds are consecutive row blocks (no shuffling). The result is the mean of
        the per-fold validation errors.
        """
        # array_split also tolerates N not divisible by num_folds (np.split raises).
        foldsX = np.array_split(dataX, num_folds, axis=0)
        foldsY = np.array_split(dataY, num_folds, axis=0)
        total = 0.0
        for i in range(num_folds):
            trainX = np.concatenate(foldsX[:i] + foldsX[i + 1:], axis=0)
            trainY = np.concatenate(foldsY[:i] + foldsY[i + 1:], axis=0)
            wREG = w_reg(trainX, trainY, namuta)
            total += zero_one_cost(pred(wREG, foldsX[i]), foldsY[i])
        return total / num_folds


    if __name__ == "__main__":
        # Regularization strengths are lambda = 10**e for these exponents. They
        # run from largest to smallest lambda, so taking the first minimum breaks
        # ties in favour of the larger lambda.
        log_lambdas = [2, 1, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10]

        # Problem 13: lambda = 10, trained on the full training set.
        dataX, dataY = read_data("hw4_train.dat")
        print("\n13")
        wREG = w_reg(dataX, dataY, namuta=10)
        Ein = zero_one_cost(pred(wREG, dataX), dataY)
        print("the Ein on the train set: ", Ein)
        testX, testY = read_data("hw4_test.dat")
        Eout = zero_one_cost(pred(wREG, testX), testY)
        print("the Eout on the test set: ", Eout)

        # Problems 14-15: sweep lambda, training on the full training set.
        print("\n14")
        Ein_list = []
        Eout_list = []
        for e in log_lambdas:
            wREG = w_reg(dataX, dataY, math.pow(10, e))
            Ein_list.append(zero_one_cost(pred(wREG, dataX), dataY))
            Eout_list.append(zero_one_cost(pred(wREG, testX), testY))
        id_in = _argmin_first(Ein_list)
        _plot_error_curve(log_lambdas, Ein_list, "Ein", "14.png")
        print("the namuta with the minimum Ein: ", math.pow(10, log_lambdas[id_in]))
        print("the Eout on such namuta: ", Eout_list[id_in])

        print("\n15")
        id_out = _argmin_first(Eout_list)
        _plot_error_curve(log_lambdas, Eout_list, "Eout", "15.png")
        print("the namuta with the minimum Eout: ", math.pow(10, log_lambdas[id_out]))

        # Problems 16-18: first 120 examples for training, the rest for validation.
        trainX, trainY = dataX[:120], dataY[:120]
        validX, validY = dataX[120:], dataY[120:]

        print("\n16")
        Etrain_list = []
        Eval_list = []
        Eout_list = []
        for e in log_lambdas:
            wREG = w_reg(trainX, trainY, math.pow(10, e))
            Etrain_list.append(zero_one_cost(pred(wREG, trainX), trainY))
            Eout_list.append(zero_one_cost(pred(wREG, testX), testY))
            Eval_list.append(zero_one_cost(pred(wREG, validX), validY))
        id_in = _argmin_first(Etrain_list)
        _plot_error_curve(log_lambdas, Etrain_list, "Ein", "16.png")
        print("the namuta with the minimum Ein: ", math.pow(10, log_lambdas[id_in]))
        print("the Eout on such namuta: ", Eout_list[id_in])

        print("\n17")
        id_val = _argmin_first(Eval_list)
        _plot_error_curve(log_lambdas, Eval_list, "Eval", "17.png")
        print("the namuta with the minimum Eval: ", math.pow(10, log_lambdas[id_val]))
        print("the Eout on such namuta: ", Eout_list[id_val])

        # Problem 18: retrain on the full training set with the validation-chosen lambda.
        print("\n18")
        wREG = w_reg(dataX, dataY, namuta=math.pow(10, log_lambdas[id_val]))
        Ein = zero_one_cost(pred(wREG, dataX), dataY)
        Eout = zero_one_cost(pred(wREG, testX), testY)
        print("Ein: ", Ein)
        print("Eout: ", Eout)

        # Problems 19-20: 5-fold cross validation.
        print("\n19")
        Ecv_list = [_cv_error(dataX, dataY, math.pow(10, e)) for e in log_lambdas]
        id_cv = _argmin_first(Ecv_list)
        _plot_error_curve(log_lambdas, Ecv_list, "Ecv", "19.png")
        print("the namuta with the minimum Ecv: ", math.pow(10, log_lambdas[id_cv]))

        print("\n20")
        wREG = w_reg(dataX, dataY, namuta=math.pow(10, log_lambdas[id_cv]))
        Ein = zero_one_cost(pred(wREG, dataX), dataY)
        Eout = zero_one_cost(pred(wREG, testX), testY)
        print("Ein: ", Ein)
        print("Eout: ", Eout)
    

    运行结果

    13

    图4 13结果

    14

    图5 14结果1
    图6 14结果2

    15

    图7 15结果1
    图8 15结果2

    16

    图9 16结果1
    图10 16结果2

    17

    图11 17结果1
    图12 17结果2

    18

    图13 18结果

    19

    图14 19结果1
    图15 19结果2

    20

    图16 20结果

  • 相关阅读:
    codeforce1214E Petya and Construction Set
    codeforces1214D Treasure Island
    CCPC2019网络赛1002 array (主席树)
    POJ2442
    计算机网络-应用层(3)Email应用
    计算机网络-应用层(2)FTP协议
    计算机网络-应用层(1)Web应用与HTTP协议
    算法-排序(1)k路平衡归并与败者树
    算法-搜索(6)B树
    RSA加密算法和SSH远程连接服务器
  • 原文地址:https://www.cnblogs.com/cherrychenlee/p/10800346.html
Copyright © 2011-2022 走看看