zoukankan      html  css  js  c++  java
  • 感知机(perceptron)概念与实现

    感知机(perceptron)

    模型:

    简答的说由输入空间(特征空间)到输出空间的如下函数:

    [f(x)=sign(wcdot x+b) ]

    称为感知机,其中,(w)(b)表示的是感知机模型参数,(w in R^n)叫做权值,(b in R)叫做偏置(bias)
    感知机是一种线性分类模型属于判别模型。

    感知机的几何解释:线性方程:$$w cdot x + b = 0$$对应于特征空间(R^n)中的一个超平面S,这个超平面将特征空间分为两个部分,位于两部分的点(特征向量)分别被分为正负两类,超平面S被称为分离超平面。

    策略

    首先感知机的数据集是对线性可分的数据集的,所谓线性可分就是存在这么一个超平面可以把数据完全正确的划分到两边。
    感知机学习的目标就是要得出(w quad b),需要确定一个(经验)损失函数,并将损失函数最小化。对于这个损失函数我们最容易想到的就是误分类的总数,但是我们也要注意到这个不能够是(w quad b)的可导连续函数,所以我们选择点误分类的点到超平面的距离作为损失函数。最终得到损失函数定义为:

    [L(w,b)=-sum_{x_i in M}y_i(w cdot x_i+b) ]

    算法

    这里我们用的是随机梯度下降法,思想是:首先随机选择一个分离超平面(w_0,b_0)然后用随机梯度下降不断最小化目标函数,最终得到完全正确的分类效果

    感知机学习算法的原始形式

    1.选择初始值(w_0,b_0)
    2.在训练集中选取数据((x_i,y_i))
    3.如果(y_i(w cdot x_i+b)le 0)

    [w gets w+eta y_i x_i$$$$b gets b+ eta y_i ]

    4.跳转至2,直至训练集中没有误分类点

    代码实现:

    w = [0, 0]
    b = 0
    
    def createDataSet():
        """
        create dataset for test
        """
        return [[(3, 3), 1], [(4, 3), 1], [(1, 1), -1]]
    
    def update(item):
        """
        update with stochastic gradient descent
        """
        global w, b
        w[0] += item[1] * item[0][0]
        w[1] += item[1] * item[0][1]
        b += item[1]
    
    def cal(item):
        """
        calculate the functional distance between 'item' an the dicision surface. output yi(w*xi+b).
        """
        res = 0
        for i in range(len(item[0])):
            res += item[0][i] * w[i]
        res += b
        res *= item[1]
        return res
    
    
    def check():
        """
        check if the hyperplane can classify the examples correctly
        """
        flag = False
        for item in training_set:
            if cal(item) <= 0:
                flag = True
                update(item)
        if not flag:
            print "RESULT: w: " + str(w) + " b: " + str(b)
        return flag
    
    if __name__ == "__main__":
        training_set = createDataSet()
        while check():
            pass
    

    感知机学习算法的对偶形式:

    1.(alpha gets 0,b gets 0)
    2.在训练集中选取数据((x_i,y_i))
    3.如果(y_i(sum_{j=1}^{N}alpha_jy_ix_jcdot x_i+b) le 0)

    [alpha_i gets alpha_i+eta$$$$b gets b+eta y_i ]

    4.转至2,直到没有误分类的数据

    代码实现

    这里主要是有一个叫做gram矩阵的东西,因为我们发现下面的计算过程中都是以内积的形式存在的,所以说这部分的值可以先算出来。(G=[x_i*x_j])

    import numpy as np
    
    def createDataSet():
        """
        create data set for test
        """
        return np.array([[[3, 3], 1], [[4, 3], 1], [[1, 1], -1]])
    
    def cal_gram():
        """
        calculate the Gram matrix
        """
        g = np.empty((len(training_set), len(training_set)), np.int)
        for i in range(len(training_set)):
            for j in range(len(training_set)):
                g[i][j] = np.dot(training_set[i][0], training_set[j][0])
        return g
    
    def update(i):
        """
        update parameters using stochastic gradient descent
        """
        global alpha, b
        alpha[i] += 1
        b = b + y[i]
    
    def cal(i):
        """
        cal
        """
        global alpha, b, x, y
        res = np.dot(alpha * y, Gram[i])
        res = (res + b) * y[i]
        return res
    
    def check():
        """
        check if the hyperplane can classify the examples correctly
        """
        global alpha, b, x, y
        flag = False
        for i in range(len(training_set)):
            if cal(i) <= 0:
                flag = True
                update(i)
        if not flag:
            w = np.dot(alpha * y, x)
            print "RESULT: w: " + str(w) + " b: " + str(b)
            return False
        return True
    
    
    if __name__ == "__main__":
        training_set = createDataSet()
        alpha = np.zeros(len(training_set), np.float)
        b = 0.0
        Gram = None
        y = np.array(training_set[:, 1])
        x = np.empty((len(training_set), 2), np.float)
        for i in range(len(training_set)):
            x[i] = training_set[i][0]
            Gram = cal_gram()
        while check():
            pass
    

    本文链接
    以上内容参考自《统计学习方法》

  • 相关阅读:
    用于主题检测的临时日志(594fb726-af0b-400d-b647-8b1d1b477d72
    返璞归真vc++之字符类型
    DIV居中
    程序员职业生涯
    枚举进程句柄
    不使用mutex设计模式解决并发访问cache
    服务器权重分配算法
    xmemecached中的一致性hash算法
    安卓课堂练习
    pythonPTA---分支循环与集合7-1 jmu-python-韩信点兵 (20分) 7-2 打印数字矩形 (10分) 7-3 成绩统计 (10分) 7-4 找列表中最大元素的下标 7-5 删除列表中的重复值
  • 原文地址:https://www.cnblogs.com/MrLJC/p/4428443.html
Copyright © 2011-2022 走看看