zoukankan      html  css  js  c++  java
  • 机器学习之一元线性回归

    概述

    线性回归是利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法,运用十分广泛。其表达形式为y = w'x+e,e为误差服从均值为0的正态分布。
    回归分析中,只包括一个自变量和一个因变量,且二者的关系可用一条直线近似表示,这种回归分析称为一元线性回归分析。
    如果回归分析中包括两个或两个以上的自变量,且因变量和自变量之间是线性关系,则称为多元线性回归分析。

    线性回归的目的

    1、用于“预测”目标值。比如根据工资预测可贷款额度,根据商场人流预测销售额,根据河流水深预测降雨量等。
    2、用于变量“分析”。比如喝可乐对于体重的影响,汽车速度对于油耗的影响等。

    一元线性回归

    定义

    回归分析只涉及到两个变量的,称一元回归分析。一元回归的主要任务是从两个相关变量中的一个变量去估计另一个变量,被估计的变量,称因变量,可设为Y;估计出的变量,称自变量,设为X。回归分析就是要找出一个数学模型Y=f(X),使得从X估计Y可以用一个函数式去计算。当Y=f(X)的形式是一个直线方程时,称为一元线性回归。这个方程一般可表示为Y=A+BX。根据最小平方法或其他方法,可以从样本数据确定常数项A与回归系数B的值。A、B确定后,有一个X的观测值,就可得到一个Y的估计值。回归方程是否可靠,估计的误差有多大,都还应经过显著性检验和误差计算。有无显著的相关关系以及样本的大小等等,是影响回归方程可靠性的因素。

    实例讲解

    背景

    王经理是一家汽车销售公司的销售负责人,他深知投放互联网视频广告对于销售收入有提振作用,经过一年的实践,得到了2017年每个月的广告投入和销售额的数据。数据如下表所示:

    月份 广告投入(万) 销售额(万)
    1 20 659
    2 22 867
    3 19 630
    4 25 940
    5 18 600
    6 27 1000
    7 30 1170
    8 15 590
    9 33 1280
    10 38 1390
    11 29 1080
    12 40 1500
    import matplotlib.pyplot as plt
    x=[20,22,19,25,18,27,30,15,33,38,22,40]
    y=[659,867,630,940,600,1000,1170,590,1280,1390,1080,1500]
    # plt.plot(x,y)#画连线图
    plt.scatter(x,y)#画散点图
    plt.show()
    

    png

    from __future__ import print_function, division
    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    #原始数据
    x_data=np.array([20,22,19,25,18,27,30,15,33,38,22,40])
    y_data=np.array([659,867,630,940,600,1000,1170,590,1280,1390,1080,1500])
    # 学习率
    learning_rate = 0.5
    # 迭代次数
    training_epochs = 1000
    # 定义运算时的占位符
    X = tf.placeholder(tf.float32)
    Y = tf.placeholder(tf.float32)
    # 定义模型参数
    W = tf.Variable(np.random.randn(), name="weight", dtype=tf.float32)
    b = tf.Variable(np.random.randn(), name="bias", dtype=tf.float32)
    # 定义模型
    pred = tf.add(tf.multiply(W, X), b)
    # 定义损失函数
    cost = tf.reduce_min(tf.pow(pred-Y, 2)/(2*100))
    # 使用Adam算法
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
    # 初始化所有变量
    init = tf.global_variables_initializer()
    # 训练开始
    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(training_epochs):
            for (x, y) in zip(x_data, y_data):
                sess.run(optimizer, feed_dict={X: x, Y: y})
            if (epoch + 1) % 50 == 0:#每50步输出一次结果
                c = sess.run(cost, feed_dict={X: x_data, Y: y_data})
                print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.3f}".format(c), "W=", sess.run(W), "b=", sess.run(b))
        print("Optimization Finished!")
        training_cost = sess.run(cost, feed_dict={X: x_data, Y: y_data})
        print("Training cost=", training_cost, "W=", sess.run(W), "b=", sess.run(b), '
    ')
    
        # 使用matplot绘图
        plt.plot(x_data, y_data, 'ro', label="Original data")
        plt.plot(x_data, sess.run(W) * x_data + sess.run(b), label="Fitted line")
        plt.legend()
        plt.show()
    
    Epoch: 0050 cost= 0.003 W= 36.60496 b= 36.615093
    Epoch: 0100 cost= 0.066 W= 36.804733 b= 31.451597
    Epoch: 0150 cost= 0.223 W= 37.028324 b= 25.547626
    Epoch: 0200 cost= 0.456 W= 37.24408 b= 19.784721
    Epoch: 0250 cost= 0.552 W= 37.4337 b= 14.668744
    Epoch: 0300 cost= 0.516 W= 37.590607 b= 10.394382
    Epoch: 0350 cost= 0.485 W= 37.715668 b= 6.9576025
    Epoch: 0400 cost= 0.460 W= 37.81316 b= 4.2589083
    Epoch: 0450 cost= 0.439 W= 37.888195 b= 2.1701117
    Epoch: 0500 cost= 0.424 W= 37.945534 b= 0.56729364
    Epoch: 0550 cost= 0.412 W= 37.989162 b= -0.6561921
    Epoch: 0600 cost= 0.402 W= 38.022297 b= -1.5871764
    Epoch: 0650 cost= 0.395 W= 38.047413 b= -2.2941654
    Epoch: 0700 cost= 0.390 W= 38.066437 b= -2.830456
    Epoch: 0750 cost= 0.386 W= 38.080845 b= -3.2369432
    Epoch: 0800 cost= 0.383 W= 38.09176 b= -3.544872
    Epoch: 0850 cost= 0.380 W= 38.100018 b= -3.7780528
    Epoch: 0900 cost= 0.379 W= 38.106274 b= -3.9546213
    Epoch: 0950 cost= 0.377 W= 38.110996 b= -4.088267
    Epoch: 1000 cost= 0.376 W= 38.114582 b= -4.189459
    Optimization Finished!
    Training cost= 0.37628764 W= 38.114582 b= -4.189459 
    

    png

    由此就可以推算出投入和产出比例。

    由于文章是由jupyter写成,博客园的markdown无法查看图片,建议移步至原文查看

  • 相关阅读:
    数据结构(2)-链表
    数据结构(1)-数组
    SpringMVC学习总结(一)--Hello World入门
    基本数据类型对象的包装类
    关于String的相关常见方法
    常见的集合容器应当避免的坑
    再一次生产 CPU 高负载排查实践
    分表后需要注意的二三事
    线程池没你想的那么简单(续)
    线程池没你想的那么简单
  • 原文地址:https://www.cnblogs.com/zhouXX/p/9967104.html
Copyright © 2011-2022 走看看