zoukankan      html  css  js  c++  java
  • logistic regression教程1

    实现线性拟合

    我们用python2.7实现上一篇的推导结果。请先安装python matplotlib包和numpy包。

    具体代码如下:

    #!/usr/bin/env python 
    #! -*- coding:utf-8 -*-
    
    import matplotlib.pyplot as plt
    from numpy import *
    
    #创建数据集
    def load_dataset():
        n = 100
        X = [[1, 0.005*xi] for xi in range(1, 100)]
        Y = [2*xi[1]  for xi in X]
        return X, Y
    
    #梯度下降法求解线性回归
    def grad_descent(X, Y):
        X = mat(X)
        Y = mat(Y)
        row, col = shape(X)
        alpha = 0.001
        maxIter = 5000
        W = ones((1, col))
        for k in range(maxIter):
            W = W + alpha * (Y - W*X.transpose())*X
        return W
    
    def main():
        X, Y = load_dataset()
        W = grad_descent(X, Y)
        print "W = ", W
    
        #绘图
        x = [xi[1] for xi in X]
        y = Y
        plt.plot(x, y, marker="*")
        xM = mat(X)
        y2 = W*xM.transpose()
        y22 = [y2[0,i] for i in range(y2.shape[1]) ]
        plt.plot(x, y22, marker="o")
        plt.show()
    
    if __name__ == "__main__":
        main()
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42

    代码超级简单,load_dataset函数创建了一个y=2x的数据集,grad_descent函数求解优化问题。

    在grad_descent里多了两个小东西,alpha是学习速率,一般取0.001~0.01,太大可能会导致震荡,求解不稳定。maxIter是最大迭代次数,它决定结果的精确度,通常是越大越好,但越大越耗时,所以通常需要试算以下,也可以另外写一个判定标准,比如当YWXT小于多少的时候就不再迭代。

    我们来看一下效果: 
    当maxIter=5时,拟合结果是这样的: 

    如果maxIter=50,拟合结果是这样的: 

    如果maxIter=500,拟合结果是这样的: 

    如果maxIter=1000,拟合结果是这样的: 

    如果maxIter=5000,拟合结果是这样的: 

    5000次的结果几乎完美,两条曲线图形重合。就酱。 
    本篇到此结束,下一篇,我们开始把logistic函数加进来,推导logistic regression。

  • 相关阅读:
    Homebrew 更改国内阿里源
    Java数组以及内存分配
    Django-Scrapy生成后端json接口
    MySQL必知必会(1-12章)
    招聘网站爬虫模板
    ubuntu之jupyter notebook配置
    服务器基本配置(ubuntu)
    Typora+PicGo+码云Gitee搭建本地博客环境
    redis缓存雪崩,穿透,击穿。整理篇
    鼠标修复升级记录(下)
  • 原文地址:https://www.cnblogs.com/developer-ios/p/5014887.html
Copyright © 2011-2022 走看看