zoukankan      html  css  js  c++  java
  • python线性拟合数据(深度学习笔记一)

    给定样本x数据:

    [1, 3, 5, 2], [2, 5, 4, 9], [1, 3, 7, 9]

    对应y数据:

    [27, 58, 67, 102]

    下面用python求线性曲线方程:
    import numpy as np
    
    # y=2x1+ 7x2+3x3+8 
    x = np.array([[1, 3, 5, 2], [2, 5, 4, 9], [1, 3, 7, 9]])
    y = np.array([[27, 58, 67, 102]])
    w = np.array([0, 0, 0]).reshape(1, 3)
    b = 0
    '''
    for i in range(4):
        print(x[0][i]*2+x[1][i]*7+x[2][i]*3+8)
    '''
    '''定义传播函数 由于不涉及到二分类等问题 所以不使用sigmoid函数  使用线性回归 取平方差的和  '''
    
    def process(w, b, X, Y):
        m = X.shape[1]
        A = np.dot(w, X) + b  # (1,4)
        assert (A.shape == (1, 4))
        cost = 1 / m * np.sum((A - Y) ** 2)  # 越小则越接近实际值
        dw = 2 / m * np.dot((A - Y), X.T)
        db = 2 / m * np.sum(A - Y)
        assert (dw.shape == w.shape)
        assert (cost.shape == ())
        dao = {'dw': dw, 'db': db}
        return dao, cost
    
    
    alpha = 0.01  # 学习率
    itemra_Num =10001  # 迭代次数
    for i in range(itemra_Num):
        dao, cost = process(w, b, x, y)
        dw = dao['dw']
        db = dao['db']
        w = w - alpha * dw
        b = b - alpha * db
        if i % 1000 == 0:
            print(w)
            print(b)
            print(cost)

    如下结果:

    [[ 3.7   7.65  7.94]]
    1.27
    4746.5
    [[ 2.30019751  7.26012923  2.81900947]]
    6.6797835089
    0.14539597224
    [[ 2.06694472  7.0580094   2.95963871]]
    7.70558877116
    0.00723054126807
    [[ 2.01492882  7.01293623  2.99099934]]
    7.93434563782
    0.000359574795807
    [[ 2.00332916  7.00288481  2.99799283]]
    7.98535893046
    1.78816535286e-05
    [[ 2.00074241  7.00064332  2.9995524 ]]
    7.99673500876
    8.89254576922e-07
    [[ 2.00016556  7.00014346  2.99990018]]
    7.99927189966
    4.42226274712e-08
    [[ 2.00003692  7.00003199  2.99997774]]
    7.999837632
    2.1991911329e-09
    [[ 2.00000823  7.00000713  2.99999504]]
    7.99996379157
    1.09365768502e-10
    [[ 2.00000184  7.00000159  2.99999889]]
    7.99999192544
    5.43875935097e-12
    [[ 2.00000041  7.00000035  2.99999975]]
    7.99999819935
    2.70469485079e-13

    从数据来看 成本函数一直在递减 说明 方向是正确的 整体上系数也越来越接近(2,7,3)+8

  • 相关阅读:
    。。
    6-4 静态内部类
    SQL把一个表里的数据赋值到另外一个表里去
    jquery 设置 disabled属性
    6-4 内部类
    DWR 整合之Struts2.3.16
    DWR整合之JSF
    DWR整合之Servlet
    dwr.xml 配置
    认识DWR
  • 原文地址:https://www.cnblogs.com/x0216u/p/7640830.html
Copyright © 2011-2022 走看看