zoukankan      html  css  js  c++  java
  • Logistic 回归梯度上升优化函数

    In [183]:

     
     
     
     
     
    def loadDataSet():
        dataMat = []
        labelMat = []
        fr = open('testSet.txt')
        for line in fr.readlines():
            lineArr = line.strip().split()
            dataMat.append([1.0,float(lineArr[0]),float(lineArr[1])])
            labelMat.append(int(lineArr[2]))
        return dataMat,labelMat
        
     
     
    In [184]:
     
     
     
     
     
    def sigmoid(inX):
        return 1.0/(1+exp(-inX))
     
     
     

    批量梯度下降

    In [185]:
     
     
     
     
     
    def gradAscent(dataMatIn, classLabels):
        dataMatrix = mat(dataMatIn)
        labelMat = mat(classLabels).transpose()
        m,n = shape(dataMatrix)
        alpha = 0.001
        maxCycles = 500
        weights = ones((n,1))
        for k in range(maxCycles):
            h = sigmoid(dataMatrix*weights) #   h是一个矩阵
            error = (labelMat - h)
            weights = weights + alpha * dataMatrix.transpose() * error
        return weights 
     
     
     

    随机梯度下降

    In [186]:
     
     
     
     
     
    def stocGradAscent0(dataMatrix, classLabels):
        m,n = shape(dataMatrix)
        alpha = 0.01
        weights = ones(n)
        #weights = [0.1,0.1,0.1]
        for i in range(m):
            h = sigmoid(sum(dataMatrix[i]*weights))#  h是一个数值
            print dataMatrix[i]
            print weights
            print dataMatrix[i]*weights
            error = classLabels[i] - h
            weights = weights + alpha * error * dataMatrix[i]
        return weights
     
     
     

    sum()的参数是一个list 下面是改进的随机梯度上升算法:

    In [187]:
     
     
     
     
     
    def stocGradAscent1(dataMatrix, classLabels, numIter=150):
        m,n = shape(dataMatrix)
        weights = ones(n)
        #weights = [0.1,0.1,0.1]
        for j in range(numIter):
            dataIndex = range(m)
            for i in range(m):
                alpha = 4/(1.0+j+i)+0.01
                randIndex = int(random.uniform(0,len(dataIndex)))
                h = sigmoid(sum(dataMatrix[randIndex]*weights))#  h是一个数值
                error = classLabels[randIndex] - h
                weights = weights + alpha * error * dataMatrix[randIndex]
                del(dataIndex[randIndex])
        return weights
     
     
    In [188]:
     
     
     
     
     
    #import logRegres
     
     
    In [189]:
     
     
     
     
     
    dataArr,labelMat = loadDataSet()
     
     
    In [190]:
     
     
     
     
     
    #weights=gradAscent(dataArr,labelMat)
    weights=stocGradAscent1(array(dataArr),labelMat,500)
     
     
    In [191]:
     
     
     
     
     
    def plotBestFit(wei):
        import matplotlib.pyplot as plt
        weights = wei
        dataMat,labelMat = loadDataSet()
        dataArr = array(dataMat)
        n = shape(dataArr)[0]
        xcord1 = []; ycord1 = []
        xcord2 = []; ycord2 = []
        for i in range(n):
            if int(labelMat[i])==1:
                xcord1.append(dataArr[i,1]);ycord1.append(dataArr[i,2])
            else:
                xcord2.append(dataArr[i,1]);ycord2.append(dataArr[i,2])
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.scatter(xcord1,ycord1,s=30,c='red',marker='s')
        ax.scatter(xcord2,ycord2,s=30,c='green')
        x = arange(-3.0,3.0,0.1)
        y = (-weights[0]-weights[1]*x)/weights[2]
        ax.plot(x,y)
        plt.xlabel('X1')
        plt.ylabel('X2')
        plt.show()
     
     
     

    h = subplot(m,n,p)/subplot(mnp) 将figure划分为m×n块,在第p块创建坐标系,并返回它的句柄。当m,n,p<10时,可以简化为subplot(mnp)或者subplot mnp (注:subplot(m,n,p)或者subplot(mnp)此函数最常用:subplot是将多个图画到一个平面上的工具。其中,m表示是图排成m行,n表示图排成n列,也就是整个figure中有n个图是排成一行的,一共m行,如果第一个数字是2就是表示2行图。p是指你现在要把曲线画到figure中哪个图上,最后一个如果是1表示是从左到右第一个位置。 )

    In [192]:
     
     
     
     
     
    from numpy import *
    #reload
    print weights
    plotBestFit(weights)
  • 相关阅读:
    Java 面向对象异常处理,finally,覆盖时异常特点,package,import,包之间的访问(10)
    Java 面向对象 异常处理:RunTimeexception,try-catch,异常声明throws,自定义异常,throw和throws的区别,多异常处理(9)
    Java 面向对象概述原理: 多态、Object类,转型(8)
    Java 接口interface(7)
    Java 继承(extends)、抽象类(abstract)的特点用法原理(7)
    Java final 关键字的用法以及原理(7)
    df-V-du
    Arch-pacman-Tips-And-Tricks
    pacman-help
    Python-Version
  • 原文地址:https://www.cnblogs.com/zhizhan/p/5450526.html
Copyright © 2011-2022 走看看