zoukankan      html  css  js  c++  java
  • [Python]机器学习:Tensorflow实现线性回归

    源码

    #>  tutorial:https://www.cnblogs.com/xianhan/p/9090426.html
    
    # 步骤一:构建模型
    
    # 1.TensorFlow 中的线性模型
    ## 占位符(Placeholder):表示执行梯度下降时将实际数据值输入到模型中的一个入口点。例如房子面积  (x) 和房价 (y_)。
    x = tf.placeholder(tf.float32,[None,1]); # X占位一条 Nx1维的向量
    
    ## 变量:表示我们试图寻找的能够使成本函数降到最小的「good」值的变量,例如 W 和 b。
    W = tf.Variable(tf.zeros([1,1])); # tf.zeros([1,1]):生成 第1行含1个元素的【二维】数组:[[ 0.]]
    b = tf.Variable(tf.zeros([1]));   # tf.zeros([1])  : 生成 第1行含1个元素的【一维数组】:[0.]
    
    ## 然后 TensorFlow 中的线性模型 (y = W.x + b) 就是:
    y = tf.matmul(x,W)+b;
    
    # 2.TensorFlow 中的成本函数
    ## 与将数据点的实际房价 (y_) 输入模型类似,我们创建一个占位符。
    y_ = tf.placeholder(tf.float32,[None,1])
    
    ## 成本函数的最小方差就是:
    cost = tf.reduce_sum(tf.pow(y_ - y,2)); # 各项样本点的最小方差之和作为拟合的成本函数
    
    # 3.数据
    ## 由于没有房价(y_) 和房子面积 (x) 的实际数据点,我们就生成它们
    ## 简单起见,我们将房价 (ys) 设置成永远是房子面积 (xs) 的 2 倍。
    for i in range(100):
        ## create fake data for actual data
        xs = np.array([[i]]);
        ys = np.array([[2*i+20]]);
        pass;
    
    # 4.梯度下降
    ## 有了线性模型、成本函数和数据,我们就可以开始执行梯度下降从而最小化代价函数,以获得 W、b 的「good」值。
    learning_rate = 0.001; ## 学习率 or步长 (每次进行训练时在最陡的梯度方向上所采取的「步」长)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); 
    
    # 步骤二:训练模型
    ## 训练包含以预先确定好的次数执行梯度下降,或者是直到成本函数低于某个预先确定的临界值为止。
    
    # 1.TensorFlow 的怪异
    ## 所有变量都需要在训练开始时进行初始化,否则它们可能会带有之前执行过程中的残余值。
    init = tf.initialize_all_variables();
    
    # 2.TensorFlow 会话
    ## 虽然 TensorFlow 是一个 Python 库,Python 是一种解释性的语言,但是默认情况下不把 TensorFlow 运算用作解释性能的原因,因此不执行上面的 init 。
    ## 相反 TensorFlow 是在一个会话中进行;创建一个会话 (sess) 然后使用 sess.run() 去执行。
    session = tf.Session();
    session.run(init)
    
    steps = 50; # 迭代次数过高以后,会产生过拟合现象【其计算出的值可能会是严重错误的拟合值】
    # 类似地我们在一个循环中调用 withinsess.run() 来执行上面的 train_step
    arrayX = [];
    arrayY = [];
    for i in range(steps):
        # Create fake data for y = W*x + b where W=2,b=0.2
        xs = np.array([[i]]); 
        ys = np.array([[2*i+0.2]]);
    #     xs = np.array([x_data[i]]); 
    #     ys = np.array([y_true[i]]);
        
        arrayX.extend(xs[0]);
        arrayY.extend(ys[0]);
        
        # Train
        feed = {x:xs,y_:ys};
        session.run(train_step,feed_dict=feed); # feed them into train_step
        
        # View 
        print("After %d iteration:"%i)
        print("W:%f"%session.run(W))
        print("b:%f"%session.run(b))
        pass; 
    
    # 可视化
    print("W:
    ",session.run(W));
    print("b:
    ",session.run(b));
    arrayX = np.array(arrayX);
    arrayX = arrayX.reshape((1,steps));
    
    arrayB = np.array(np.full(steps,session.run(b)));
    arrayB = arrayB.reshape(1,steps);
    arrayB = np.transpose(arrayB)
    # print("arrayB:
    ",arrayB);
    predictYs = np.dot(np.transpose(arrayX),session.run(W))+ arrayB;
    
    # print(predictYs);
    # print(arrayX)
    #print(arrayY)
    
    plt.rcParams['figure.dpi'] = 300 #分辨率
    plt.scatter(arrayX, arrayY, marker = '*',color = 'red', s = 10 ,label = 'Actual Dataset')
    plt.scatter(arrayX, predictYs, marker = 'o',color = 'green', s = 8 ,label = 'Fit Dataset')
    plt.legend(loc = 'best')    # 设置 图例所在的位置 使用推荐位置
    
    After 0 iteration:
    W:0.000000
    b:0.000400
    After 1 iteration:
    W:0.004399
    b:0.004799
    After 2 iteration:
    W:0.021145
    b:0.013172
    After 3 iteration:
    W:0.057885
    b:0.025419
    After 4 iteration:
    W:0.121430
    b:0.041305
    After 5 iteration:
    W:0.216945
    b:0.060408
    After 6 iteration:
    W:0.347000
    b:0.082084
    After 7 iteration:
    W:0.510645
    b:0.105462
    After 8 iteration:
    W:0.702795
    b:0.129480
    After 9 iteration:
    W:0.914212
    b:0.152971
    After 10 iteration:
    W:1.132310
    b:0.174781
    After 11 iteration:
    W:1.342846
    b:0.193921
    After 12 iteration:
    W:1.532252
    b:0.209704
    After 13 iteration:
    W:1.690099
    b:0.221846
    After 14 iteration:
    W:1.810968
    b:0.230480
    After 15 iteration:
    W:1.895118
    b:0.236090
    After 16 iteration:
    W:1.947663
    b:0.239374
    After 17 iteration:
    W:1.976575
    b:0.241075
    After 18 iteration:
    W:1.990276
    b:0.241836
    After 19 iteration:
    W:1.995707
    b:0.242122
    After 20 iteration:
    W:1.997456
    b:0.242209
    After 21 iteration:
    W:1.997927
    b:0.242232
    After 22 iteration:
    W:1.998075
    b:0.242238
    After 23 iteration:
    W:1.998169
    b:0.242242
    After 24 iteration:
    W:1.998251
    b:0.242246
    After 25 iteration:
    W:1.998325
    b:0.242249
    After 26 iteration:
    W:1.998393
    b:0.242251
    After 27 iteration:
    W:1.998455
    b:0.242254
    After 28 iteration:
    W:1.998512
    b:0.242256
    After 29 iteration:
    W:1.998564
    b:0.242258
    After 30 iteration:
    W:1.998613
    b:0.242259
    After 31 iteration:
    W:1.998658
    b:0.242261
    After 32 iteration:
    W:1.998701
    b:0.242262
    After 33 iteration:
    W:1.998741
    b:0.242263
    After 34 iteration:
    W:1.998778
    b:0.242264
    After 35 iteration:
    W:1.998813
    b:0.242265
    After 36 iteration:
    W:1.998846
    b:0.242266
    After 37 iteration:
    W:1.998878
    b:0.242267
    After 38 iteration:
    W:1.998907
    b:0.242268
    After 39 iteration:
    W:1.998935
    b:0.242269
    After 40 iteration:
    W:1.998960
    b:0.242269
    After 41 iteration:
    W:1.998989
    b:0.242270
    After 42 iteration:
    W:1.999004
    b:0.242270
    After 43 iteration:
    W:1.999050
    b:0.242271
    After 44 iteration:
    W:1.999007
    b:0.242270
    After 45 iteration:
    W:1.999222
    b:0.242275
    After 46 iteration:
    W:1.998624
    b:0.242262
    After 47 iteration:
    W:2.000731
    b:0.242307
    After 48 iteration:
    W:1.993301
    b:0.242152
    After 49 iteration:
    W:2.021338
    b:0.242724
    
    W:
     [[ 2.02133822]]
    b:
     [ 0.24272442]
    

    推荐文献

  • 相关阅读:
    【Java EE 学习 81】【CXF框架】【CXF整合Spring】
    【Java EE 学习 80 下】【调用WebService服务的四种方式】【WebService中的注解】
    【Java EE 学习 80 上】【WebService】
    【Java EE 学习 79 下】【动态SQL】【mybatis和spring的整合】
    【Java EE 学习 79 上】【mybatis 基本使用方法】
    【Java EE 学习 78 下】【数据采集系统第十天】【数据采集系统完成】
    【Java EE 学习 78 中】【数据采集系统第十天】【Spring远程调用】
    【Java EE 学习 78 上】【数据采集系统第十天】【Service使用Spring缓存模块】
    【Java EE 学习 77 下】【数据采集系统第九天】【使用spring实现答案水平分库】【未解决问题:分库查询问题】
    【Java EE 学习 77 上】【数据采集系统第九天】【通过AOP实现日志管理】【通过Spring石英调度动态生成日志表】【日志分表和查询】
  • 原文地址:https://www.cnblogs.com/johnnyzen/p/10856815.html
Copyright © 2011-2022 走看看