zoukankan      html  css  js  c++  java
  • 线性回归&&code

     1 # -*- coding: utf-8 -*-
     2 
     3 import numpy as np
     4 import matplotlib.pyplot as plt
     5 from certifi import __main__
     6 
     7 def cost(x,y,theta=np.zeros((2,1))):
     8    m=len(y);
     9    J=1.0/(2*m)*sum((x.dot(theta).flatten()-y)**2);
    10    return J;
    11 
    12 def gradientDesc(x,y,theat=np.zeros((2,1)),alpha=0.001,iterations=1500):
    13     m=len(y)
    14     J=[]
    15     for i in xrange(iterations):
    16         a=theat[0][0]-alpha*(1.0/m)*sum((x.dot(theat).flatten()-y)*x[:,0]);
    17         b=theat[1][0]-alpha*(1.0/m)*sum((x.dot(theat).flatten()-y)*1);
    18         theat[0][0],theat[1][0]=a,b
    19         print theat[0][0], theat[1][0]
    20         print cost(x, y, theat);
    21         
    22     return theat;
    23 
    24 if __name__=="__main__":
    25     x=np.array([[9,1],[15,1],[25,1],[14,1],[10,1],[18,1]]);
    26     y=np.array([39,56,93,61,50,75]);
    27     ans=gradientDesc(x, y);
    28     xx=[1,30]
    29     yy=[ans[0][0]*1+ans[1][0],ans[0][0]*30+ans[1][0]]
    30     plt.plot(xx,yy)
    31     plt.scatter(x[:,0],y)
    32     plt.show()
    33     
    34     print 'end'
    35 
    36 #显示数据
    37 '''
    38 plt.scatter(x,y);
    39 plt.show();
    40 '''

    结果显示（原文此处为拟合直线与数据散点的结果图，图片在抓取时已丢失）

  • 相关阅读:
    C++内存管理
    GitHub 简单用法
    Tembin
    git
    js 插件使用总结
    cas sso
    Redis实战
    全面分析 Spring 的编程式事务管理及声明式事务管理
    mybatis
    b2b
  • 原文地址:https://www.cnblogs.com/wuxiangli/p/5879514.html
Copyright © 2011-2022 走看看