zoukankan      html  css  js  c++  java
  • 感知机(perceptron)概念与实现

    感知机(perceptron)

    模型:

    简答的说由输入空间(特征空间)到输出空间的如下函数:

    [f(x)=sign(wcdot x+b) ]

    称为感知机,其中,(w)(b)表示的是感知机模型参数,(w in R^n)叫做权值,(b in R)叫做偏置(bias)
    感知机是一种线性分类模型属于判别模型。

    感知机的几何解释:线性方程:$$w cdot x + b = 0$$对应于特征空间(R^n)中的一个超平面S,这个超平面将特征空间分为两个部分,位于两部分的点(特征向量)分别被分为正负两类,超平面S被称为分离超平面。

    策略

    首先感知机的数据集是对线性可分的数据集的,所谓线性可分就是存在这么一个超平面可以把数据完全正确的划分到两边。
    感知机学习的目标就是要得出(w quad b),需要确定一个(经验)损失函数,并将损失函数最小化。对于这个损失函数我们最容易想到的就是误分类的总数,但是我们也要注意到这个不能够是(w quad b)的可导连续函数,所以我们选择点误分类的点到超平面的距离作为损失函数。最终得到损失函数定义为:

    [L(w,b)=-sum_{x_i in M}y_i(w cdot x_i+b) ]

    算法

    这里我们用的是随机梯度下降法,思想是:首先随机选择一个分离超平面(w_0,b_0)然后用随机梯度下降不断最小化目标函数,最终得到完全正确的分类效果

    感知机学习算法的原始形式

    1.选择初始值(w_0,b_0)
    2.在训练集中选取数据((x_i,y_i))
    3.如果(y_i(w cdot x_i+b)le 0)

    [w gets w+eta y_i x_i$$$$b gets b+ eta y_i ]

    4.跳转至2,直至训练集中没有误分类点

    代码实现:

    w = [0, 0]
    b = 0
    
    def createDataSet():
        """
        create dataset for test
        """
        return [[(3, 3), 1], [(4, 3), 1], [(1, 1), -1]]
    
    def update(item):
        """
        update with stochastic gradient descent
        """
        global w, b
        w[0] += item[1] * item[0][0]
        w[1] += item[1] * item[0][1]
        b += item[1]
    
    def cal(item):
        """
        calculate the functional distance between 'item' an the dicision surface. output yi(w*xi+b).
        """
        res = 0
        for i in range(len(item[0])):
            res += item[0][i] * w[i]
        res += b
        res *= item[1]
        return res
    
    
    def check():
        """
        check if the hyperplane can classify the examples correctly
        """
        flag = False
        for item in training_set:
            if cal(item) <= 0:
                flag = True
                update(item)
        if not flag:
            print "RESULT: w: " + str(w) + " b: " + str(b)
        return flag
    
    if __name__ == "__main__":
        training_set = createDataSet()
        while check():
            pass
    

    感知机学习算法的对偶形式:

    1.(alpha gets 0,b gets 0)
    2.在训练集中选取数据((x_i,y_i))
    3.如果(y_i(sum_{j=1}^{N}alpha_jy_ix_jcdot x_i+b) le 0)

    [alpha_i gets alpha_i+eta$$$$b gets b+eta y_i ]

    4.转至2,直到没有误分类的数据

    代码实现

    这里主要是有一个叫做gram矩阵的东西,因为我们发现下面的计算过程中都是以内积的形式存在的,所以说这部分的值可以先算出来。(G=[x_i*x_j])

    import numpy as np
    
    def createDataSet():
        """
        create data set for test
        """
        return np.array([[[3, 3], 1], [[4, 3], 1], [[1, 1], -1]])
    
    def cal_gram():
        """
        calculate the Gram matrix
        """
        g = np.empty((len(training_set), len(training_set)), np.int)
        for i in range(len(training_set)):
            for j in range(len(training_set)):
                g[i][j] = np.dot(training_set[i][0], training_set[j][0])
        return g
    
    def update(i):
        """
        update parameters using stochastic gradient descent
        """
        global alpha, b
        alpha[i] += 1
        b = b + y[i]
    
    def cal(i):
        """
        cal
        """
        global alpha, b, x, y
        res = np.dot(alpha * y, Gram[i])
        res = (res + b) * y[i]
        return res
    
    def check():
        """
        check if the hyperplane can classify the examples correctly
        """
        global alpha, b, x, y
        flag = False
        for i in range(len(training_set)):
            if cal(i) <= 0:
                flag = True
                update(i)
        if not flag:
            w = np.dot(alpha * y, x)
            print "RESULT: w: " + str(w) + " b: " + str(b)
            return False
        return True
    
    
    if __name__ == "__main__":
        training_set = createDataSet()
        alpha = np.zeros(len(training_set), np.float)
        b = 0.0
        Gram = None
        y = np.array(training_set[:, 1])
        x = np.empty((len(training_set), 2), np.float)
        for i in range(len(training_set)):
            x[i] = training_set[i][0]
            Gram = cal_gram()
        while check():
            pass
    

    本文链接
    以上内容参考自《统计学习方法》

  • 相关阅读:
    Spring Boot基础
    MyBatis开启二级缓存
    MyBatis逆向工程
    html实现“加入收藏”代码
    vue-router 基本使用
    vue 脚手架安装
    webpack入门 webpack4常见出错之处
    $.ajax()方法详解
    防止网页被嵌套
    H5字符实体参考
  • 原文地址:https://www.cnblogs.com/MrLJC/p/4428443.html
Copyright © 2011-2022 走看看