zoukankan      html  css  js  c++  java
  • 【作业一】林轩田机器学习基石

    作业方面,暂时只关注需要编程的题目了,用python完成代码。

    Q15~Q17应用的是传统PLA算法,给定的数据集也是保证线性可分的。

    代码需要完成的就是实现一个简单的PLA,并且“W = W + speed*yX”中的speed是可以配置的(即学习速率)

    代码1

    #encoding=utf8
    import sys
    import numpy as np
    import math
    
    def parse_sample(line):
        """Parse one "x1 x2 ... xd<TAB>y" data line into an (X, y) pair.

        The last whitespace-separated token is the label y (+1 or -1); the
        tokens before it are the features.  A constant 1.0 is prepended to X so
        that W[0] acts as the bias (intercept) term -- without it the decision
        boundary is forced through the origin.  *line* must not be blank.
        """
        values = line.split()
        X = [1.0] + [float(v) for v in values[:-1]]
        return X, int(values[-1])


    def read_samples(path):
        """Return the (X, y) pairs of every non-blank line in the file *path*."""
        with open(path) as f:
            return [parse_sample(line) for line in f if line.strip()]


    def pla_naive_cycle(samples, speed=1.0, max_passes=1000):
        """Run the perceptron learning algorithm (PLA) on *samples*.

        Samples are visited in the given order, cycling through the list again
        and again, until one complete pass makes no mistake (PLA "halts").  On
        every mistake W is updated with W = W + speed * y * X.

        Args:
            samples: list of (X, y) pairs; X is a list of floats (bias first),
                y is +1 or -1.
            speed: learning rate.
            max_passes: safety bound on the number of passes, because PLA never
                halts on data that is not linearly separable.

        Returns:
            (W, updates): the final weights and the total number of updates.

        Raises:
            ValueError: if the samples do not all have the same length.
            RuntimeError: if no mistake-free pass happens within max_passes.
        """
        if not samples:
            return [], 0
        dim = len(samples[0][0])
        if any(len(X) != dim for X, _ in samples):
            raise ValueError("all samples must have the same number of features")
        W = [0.0] * dim
        updates = 0
        for _ in range(max_passes):
            mistakes = 0
            for X, y in samples:
                score = sum(w * x for w, x in zip(W, X))
                prediction = 1 if score > 0.0 else -1  # a score of exactly 0 counts as -1
                if prediction != y:
                    mistakes += 1
                    W = [w + speed * y * x for w, x in zip(W, X)]
            updates += mistakes
            # A whole pass without a mistake means W separates every sample.
            if mistakes == 0:
                return W, updates
        raise RuntimeError("PLA did not halt within %d passes "
                           "(is the data linearly separable?)" % max_passes)


    if __name__ == '__main__':
        W, halts = pla_naive_cycle(read_samples("train.dat"))
        for w in W:
            print(w)
        print("halts:" + str(halts))

    代码2(随机打乱样本顺序)

    #encoding=utf8
    import sys
    import numpy as np
    import math
    from random import *
    
    def parse_sample(line):
        """Parse one "x1 x2 ... xd<TAB>y" data line into an (X, y) pair.

        The last whitespace-separated token is the label y (+1 or -1); the
        tokens before it are the features.  A constant 1.0 is prepended to X so
        that W[0] acts as the bias (intercept) term -- without it the decision
        boundary is forced through the origin.  *line* must not be blank.
        """
        values = line.split()
        X = [1.0] + [float(v) for v in values[:-1]]
        return X, int(values[-1])


    def read_samples(path):
        """Return the (X, y) pairs of every non-blank line in the file *path*."""
        with open(path) as f:
            return [parse_sample(line) for line in f if line.strip()]


    def pla_naive_cycle(samples, speed=1.0, max_passes=1000):
        """Run the perceptron learning algorithm (PLA) on *samples*.

        Samples are visited in the given order, cycling through the list again
        and again, until one complete pass makes no mistake (PLA "halts").  On
        every mistake W is updated with W = W + speed * y * X.

        Args:
            samples: list of (X, y) pairs; X is a list of floats (bias first),
                y is +1 or -1.
            speed: learning rate.
            max_passes: safety bound on the number of passes, because PLA never
                halts on data that is not linearly separable.

        Returns:
            (W, updates): the final weights and the total number of updates.

        Raises:
            ValueError: if the samples do not all have the same length.
            RuntimeError: if no mistake-free pass happens within max_passes.
        """
        if not samples:
            return [], 0
        dim = len(samples[0][0])
        if any(len(X) != dim for X, _ in samples):
            raise ValueError("all samples must have the same number of features")
        W = [0.0] * dim
        updates = 0
        for _ in range(max_passes):
            mistakes = 0
            for X, y in samples:
                score = sum(w * x for w, x in zip(W, X))
                prediction = 1 if score > 0.0 else -1  # a score of exactly 0 counts as -1
                if prediction != y:
                    mistakes += 1
                    W = [w + speed * y * x for w, x in zip(W, X)]
            updates += mistakes
            # A whole pass without a mistake means W separates every sample.
            if mistakes == 0:
                return W, updates
        raise RuntimeError("PLA did not halt within %d passes "
                           "(is the data linearly separable?)" % max_passes)


    if __name__ == '__main__':
        # params
        TIMES = 2000
        SPEED = 0.5
        samples = read_samples("train.dat")
        rng = Random()
        sum_halts = 0
        for seed in range(TIMES):
            # Shuffle a fresh copy with this run's own seed.  Each run then
            # cycles through one random order that depends only on its seed,
            # not on the shuffles of earlier runs.
            order = list(samples)
            rng.seed(seed)
            rng.shuffle(order)
            _, halts = pla_naive_cycle(order, speed=SPEED)
            print("halts:" + str(halts))
            sum_halts += halts
        # Average over all TIMES runs; 1.0 * forces true division on Python 2.
        print("average halts:" + str(1.0 * sum_halts / TIMES))

    这几道题可以得到的结论是:打乱样本顺序会影响收敛所需的更新次数;而当W从零向量开始时,学习速率speed只是把每一步的W整体乘上同一个正数,不改变sign(W·X)的结果,因此不会改变更新次数。

    另外,还有一个细节:一定不要忘记加上偏置项W0(即常数项截距,做法是在每个X前补一个常数1)。否则分界面被限制为必须经过原点,数据可能无法被正确分开,始终会存在误分类的样本,算法无法收敛。

    ==============================================

    作业Q18~Q20考查的是pocket pla

    即训练数据不是线性可分的情况(实际中也多是如此),此时需要把PLA改进为pocket PLA。

    之前一直没理解好pocket的意思,后来参考了讨论区的内容,理解了Pocket的意思。

    简而言之就是,“pocket不影响pla的正常运行,每轮W该更新还是要更新;pocket只需要维护历史出现的W中,在train_data上error最小的那个即可”

    #encoding=utf8
    import sys
    import numpy as np
    import math
    from random import *
    
    def error_on_data(data, W):
        """Count how many lines of *data* the weight vector *W* misclassifies.

        Args:
            data: list of "x1 x2 ... xd<TAB>y" strings with y in {+1, -1};
                the label is the last whitespace-separated token.  Blank lines
                are ignored.
            W: weights with the bias first, i.e. len(W) == d + 1.

        Returns:
            The number of misclassified lines (an int).

        Raises:
            ValueError: if a line's feature count does not match len(W).
        """
        error_W = 0
        for line in data:
            values = line.split()
            if not values:
                continue  # blank line, e.g. an empty last line in the data file
            # prepend the constant 1 so that W[0] acts as the bias term
            X = [1.0] + [float(v) for v in values[:-1]]
            if len(X) != len(W):
                raise ValueError("sample has %d features but W has %d weights"
                                 % (len(X) - 1, len(W)))
            score_W = sum(x * float(w) for x, w in zip(X, W))
            sign_W = 1 if score_W > 0.0 else -1  # a score of exactly 0 counts as -1
            if sign_W != int(values[-1]):
                error_W += 1
        return error_W
    
    def pocket_algorithm(train_data, r, updates=100, use_pocket=True, verbose=False):
        """Run PLA with a "pocket" on (possibly non-separable) training data.

        PLA itself runs unchanged: a training line is drawn at random with
        ``r.randint`` and, whenever it is misclassified, W is updated with
        W = W + y * X.  The pocket only remembers the W with the fewest
        training errors seen so far; it never feeds back into the PLA updates.

        Args:
            train_data: list of "x1 x2 ... xd<TAB>y" strings with y in {+1, -1};
                blank lines are ignored.
            r: random source with a ``randint(a, b)`` method, e.g. random.Random.
            updates: number of PLA updates (mistake corrections) to perform.
            use_pocket: if True return the pocket weights, otherwise the final
                PLA weights W.
            verbose: print the current and best training error after each update.

        Returns:
            The chosen weight vector as a list of floats, bias first.

        Raises:
            ValueError: if train_data holds no samples or the samples have
                different lengths.
        """
        # Parse once instead of re-parsing every line on each error count.
        samples = []
        for line in train_data:
            values = line.split()
            if values:
                # prepend the constant 1 so that W[0] acts as the bias term
                samples.append(([1.0] + [float(v) for v in values[:-1]], int(values[-1])))
        if not samples:
            raise ValueError("train_data contains no samples")
        dim = len(samples[0][0])
        if any(len(X) != dim for X, _ in samples):
            raise ValueError("all samples must have the same number of features")

        def predict(W, X):
            score = sum(x * w for x, w in zip(X, W))
            return 1 if score > 0.0 else -1  # a score of exactly 0 counts as -1

        def count_errors(W):
            return sum(1 for X, y in samples if predict(W, X) != y)

        W = [0.0] * dim
        best_W = list(W)
        best_error = count_errors(best_W)
        rounds = 0
        # best_error == 0 means the current W already classifies every sample
        # correctly, so no further mistake (and no further update) can ever
        # occur; stop instead of drawing samples forever.
        while rounds < updates and best_error > 0:
            X, y = samples[r.randint(0, len(samples) - 1)]
            if predict(W, X) == y:
                continue
            rounds += 1
            W = [w + y * x for w, x in zip(W, X)]
            curr_error = count_errors(W)
            if verbose:
                print("curr_error:" + str(curr_error) + ",best_error:" + str(best_error))
            if curr_error < best_error:
                best_W = list(W)
                best_error = curr_error
        return best_W if use_pocket else W
    
    if __name__ == '__main__':
        # Read raw data, skipping blank lines so they neither break parsing nor
        # inflate the test-set size used as the error-rate denominator.
        with open("train2.dat") as f:
            train_data = [line.strip() for line in f if line.strip()]
        with open("test2.dat") as f:
            test_data = [line.strip() for line in f if line.strip()]
        # Average the test error rate over several pocket runs, one seed each.
        iterative_times = 100
        total_error_times = 0
        r = Random()
        for i in range(iterative_times):
            r.seed(i)
            W = pocket_algorithm(train_data, r)
            total_error_times += error_on_data(test_data, W)
        print(str((1.0 * total_error_times) / (iterative_times * len(test_data))))

     这个参考资料解释了Pocket 算法是怎么样运行的

    https://class.coursera.org/ntumlone-002/forum/thread?thread_id=79

  • 相关阅读:
    1.14 作业
    1.12作业
    1.9 作业 矩阵转置与输出九宫格
    1.8 作业
    1.7 作业 打印菱形
    1.5 作业
    1.4作业 不同的年龄,不同的问候语
    PHP语言 -- 发起流程
    PHP语言 -- 新建流程
    PHP语言 -- 权限
  • 原文地址:https://www.cnblogs.com/xbf9xbf/p/4578521.html
Copyright © 2011-2022 走看看