zoukankan      html  css  js  c++  java
  • 机器学习实战教程(二):逻辑回归基础篇(上)

    一 逻辑回归是什么?

      首先虽然名字中带有回归两个字,但是这是一个不折不扣的分类算法。假设有一场足球赛,我们有两支球队的所有出场球员的信息、历史交锋成绩、比赛时间、主客场、裁判、天气等因素,根据这些因素去预测球队能否赢球。假设比赛结果记录为 y ,赢球标记为1,输球标记为 0,这就是个典型的二分类问题,可以用逻辑回归算法来解决。



    二 如何实现逻辑回归?

     首先需要找到一个预测函数,使其输出值在[0, 1]之间,然后选择一个基准值,如0.5, 当大于0.5另其为1,小于0.5为0,,这这里我们选择sigmoid函数:

      

      之所以选择sigmoid函数是因为:sigmoid函数是单位阶跃函数的一种变形,单位阶跃函数的问题在于该函数在跳跃点上从0 到 1 这个瞬间很难把握。


    三 公式推导

      

      整合成一个公式,就变成了如下公式:

    机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

      

      根据sigmoid函数的特性,我们可以做出如下的假设:

    机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

      我们可以把上述两个概率公式合二为一:

    机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

      合并出来的Cost,我们称之为代价函数(Cost Function)。当y等于1时,(1-y)项(第二项)为0;当y等于0时,y项(第一项)为0。为了简化问题,我们对整个表达式求对数,(将指数问题对数化是处理数学问题常见的方法):

    机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

      这个代价函数,是对于一个样本而言的。给定一个样本,我们就可以通过这个代价函数求出,样本所属类别的概率,而这个概率越大越好,所以也就是求解这个代价函数的最大值。既然概率出来了,那么最大似然估计也该出场了。假定样本与样本之间相互独立,那么整个样本集生成的概率即为所有样本生成概率的乘积,再将公式对数化,便可得到如下公式:

    机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

      其中,m为样本的总数,y(i)表示第i个样本的类别,x(i)表示第i个样本,需要注意的是θ是多维向量,x(i)也是多维向量。

      综上所述,满足J(θ)的最大的θ值即是我们需要求解的模型。

      怎么求解使J(θ)最大的θ值呢?因为是求最大值,所以我们需要使用梯度上升算法。如果面对的问题是求解使J(θ)最小的θ值,那么我们就需要使用梯度下降算法。面对我们这个问题,如果使J(θ) := -J(θ),那么问题就从求极大值转换成求极小值了,使用的算法就从梯度上升算法变成了梯度下降算法。


      数据集:

      首先将数据集里面的数据全部显示出来:

     1 def plotDataSet():
     2     dataMat, labelMat = loadDataSet()                                    #加载数据集
     3     dataArr = np.array(dataMat)                                            #转换成numpy的array数组
     4     n = np.shape(dataMat)[0]                                            #数据个数
     5     xcord1 = []; ycord1 = []                                            #正样本
     6     xcord2 = []; ycord2 = []                                            #负样本
     7     for i in range(n):                                                    #根据数据集标签进行分类
     8         if int(labelMat[i]) == 1:
     9             xcord1.append(dataArr[i,1]); ycord1.append(dataArr[i,2])    #1为正样本
    10         else:
    11             xcord2.append(dataArr[i,1]); ycord2.append(dataArr[i,2])    #0为负样本
    12     fig = plt.figure()
    13     ax = fig.add_subplot(111)                                            #添加subplot
    14     ax.scatter(xcord1, ycord1, s = 20, c = 'red', marker = 's',alpha=.5)#绘制正样本
    15     ax.scatter(xcord2, ycord2, s = 20, c = 'green',alpha=.5)            #绘制负样本
    16     plt.title('DataSet')                                                #绘制title
    17     plt.xlabel('x'); plt.ylabel('y')                                    #绘制label
    18     plt.show()                                                            #显示
    19  
    20 if __name__ == '__main__':
    21     plotDataSet()
    View Code

      

      2、训练算法

      在编写代码之前,让我们回顾下梯度上升迭代公式:

    机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

      将上述公式矢量化:

    机器学习实战教程(六):Logistic回归基础篇之梯度上升算法

    # -*- coding: utf-8 -*-
    """
    Created on Sun Jun 10 13:08:09 2018
    
    @author: Administrator
    """
    
    import numpy as np
    import matplotlib.pyplot as plt
    
    def loadDataSet():
        dataMat = []                                                        #创建数据列表
        labelMat = []                                                        #创建标签列表
        fr = open('testSet.txt')                                            #打开文件   
        for line in fr.readlines():                                            #逐行读取
            lineArr = line.strip().split()                                    #去回车,放入列表
            dataMat.append([1.0, float(lineArr[0]), float(lineArr[1])])        #添加数据
            labelMat.append(int(lineArr[2]))                                #添加标签
        fr.close()                                                            #关闭文件
        return dataMat, labelMat                                            #返回
    
    def plotDataSet():
        dataMat, labelMat = loadDataSet()                                    #加载数据集
        dataArr = np.array(dataMat)                                            #转换成numpy的array数组
        n = np.shape(dataMat)[0]                                            #数据个数
        xcord1 = []; ycord1 = []                                            #正样本
        xcord2 = []; ycord2 = []                                            #负样本
        for i in range(n):                                                    #根据数据集标签进行分类
            if int(labelMat[i]) == 1:
                xcord1.append(dataArr[i,1]); ycord1.append(dataArr[i,2])    #1为正样本
            else:
                xcord2.append(dataArr[i,1]); ycord2.append(dataArr[i,2])    #0为负样本
        fig = plt.figure()
        ax = fig.add_subplot(111)                                            #添加subplot
        ax.scatter(xcord1, ycord1, s = 20, c = 'red', marker = 's',alpha=.5)#绘制正样本
        ax.scatter(xcord2, ycord2, s = 20, c = 'green',alpha=.5)            #绘制负样本
        plt.title('DataSet')                                                #绘制title
        plt.xlabel('x'); plt.ylabel('y')                                    #绘制label
        plt.show()                                                            #显示
     
    def sigmoid (inX):
        return 1.0 / (1 + np.exp(-inX))
    
    def gradAscent(dataMatIn, classLabels):
        dataMat = np.mat(dataMatIn)
        LabelMat = np.mat(classLabels).transpose()
        m, n = np.shape(dataMat)
        alpha = 0.001
        maxCycles = 500
        weights = np.ones((n, 1))
        
        for k in range(maxCycles):
            h = sigmoid(dataMat * weights)
            error = LabelMat - h
            weights = weights  + alpha * dataMat.transpose() * error
            
        return weights
        
    if __name__ == '__main__':
        dataMat, labelMat = loadDataSet()
        weights = gradAscent(dataMat, labelMat)
        print (weights)
    View Code

      将这个预测显示出来

        

  • 相关阅读:
    博客停止更新了,新博客地址见github
    SSH登录过程
    哈希表结构
    静态链接、动态链接
    编译、汇编、链接、加载
    IO复用 select epoll
    kali安装盘
    linux常用命令
    DDOS攻击防范系统的设计与实现
    20155202《网络对抗》Exp9 web安全基础实践
  • 原文地址:https://www.cnblogs.com/NaLaEur/p/9163125.html
Copyright © 2011-2022 走看看