zoukankan      html  css  js  c++  java
  • 简单线性回归(梯度下降法) python实现

    grad_desc

     

    简单线性回归(梯度下降法)

     

    0.引入依赖

    In [1]:
    import numpy as np
    import matplotlib.pyplot as plt
    
     

    1.导入数据

    In [34]:
    points = np.genfromtxt("data.csv",delimiter=",")
    #points
    #提取points中的两列数据,分别作为x,y
    x=points[:,0];
    y=points[:,1];
    
    #用plt画出散点图
    plt.scatter(x,y)
    plt.show()
    
     
     

    2.定义损失函数

    In [35]:
    # 损失函数是系数的函数,另外还要传入数据的x,y
    def compute_cost(w,b,points):
        total_cost=0
        M =len(points)
        for i in range(M):
            x=points[i,0]
            y=points[i,1]
            total_cost += (y-w*x-b)**2
        return total_cost/M #一除都是浮点 两个除号是地板除,整型。 如 3 // 4
    
     

    3.定义模型的超参数

    In [52]:
    alpha = 0.0000001
    initial_w = 0
    initial_b = 0
    num_iter =20
    
     

    4.定义核心梯度下降算法函数

    In [37]:
    def grad_desc(points,initial_w,initial_b,alpha,num_iter):
        w = initial_w
        b = initial_b
        # 定义一个list保存所有的损失函数值,用来显示下降过程。
        cost_list=[]
        for i in range(num_iter):
            cost_list.append(compute_cost(w,b,points))
            w,b= step_grad_desc(w,b,alpha,points)
        return [w,b,cost_list]
    
    def step_grad_desc(current_w,current_b,alpha,points):
        sum_grad_w=0
        sum_grad_b=0
        M=len(points)
        #对每个点代入公式求和
        for i in range(M):
            x= points[i,0]
            y= points[i,1]
            sum_grad_w += (current_w * x +current_b -y) *x
            sum_grad_b +=  current_w * x +current_b -y
        #用公式求当前梯度
        grad_w=2/M * sum_grad_w
        grad_b=2/M * sum_grad_b
        
        #梯度下降,更新当前的w和b
        updated_w = current_w- alpha * grad_w
        updated_b = current_b -alpha * grad_b
        return updated_w,updated_b
    
     

    5.测试,运行梯度下降算法

    In [54]:
    w,b,cost_list= grad_desc(points,initial_w,initial_b,alpha,num_iter)
    print ("w is :",w)
    print ("b is :",b)
    
    cost = compute_cost(w,b,points)
    
    print("cost_list:",cost_list)
    print("cost is:",cost)
    plt.plot(cost_list)
    
     
    w is : 1.9845988031472985
    b is : 0.0004970348345541671
    cost_list: [30684366.833333332, 9539857.724899232, 2973884.507279095, 934962.5312039739, 301819.09812286275, 105210.00196432497, 44157.269403835446, 25198.654436632325, 19311.463701577555, 17483.323048238828, 16915.633228203948, 16739.349345812665, 16684.608171772015, 16667.609475636003, 16662.3308954798, 16660.691745647422, 16660.182742743986, 16660.02468269902, 16659.975600422167, 16659.960358850854]
    cost is: 16659.95562578394
    
    Out[54]:
    [<matplotlib.lines.Line2D at 0x1218cd978>]
     
    In [55]:
    plt.scatter(x,y)
    
    pred_y= w*x+b
    
    plt.plot(x,pred_y,c='r')
    
    Out[55]:
    [<matplotlib.lines.Line2D at 0x121984940>]
     
    In [ ]:
     
  • 相关阅读:
    Visual Studio 2010的活动,有兴趣的朋友可以来参加
    .NET 业务框架开发实战之九 Mapping属性原理和验证规则的实现策略
    Javascript 返回上一页
    Entity Framework 4.0新增对TSQL的支持
    .Net 4.0中支持的更加完善了协变和逆变
    c#4.0——泛型委托的协变、逆变
    JQuery 常用方法基础教程
    AspNetPager分页示例
    微软一站式示例代码浏览器
    UI与实体的映射
  • 原文地址:https://www.cnblogs.com/arli/p/11428236.html
Copyright © 2011-2022 走看看