zoukankan      html  css  js  c++  java
  • [Python]机器学习:Tensorflow实现线性回归

    源码

    #>  tutorial:https://www.cnblogs.com/xianhan/p/9090426.html
    
    # 步骤一:构建模型
    
    # 1.TensorFlow 中的线性模型
    ## 占位符(Placeholder):表示执行梯度下降时将实际数据值输入到模型中的一个入口点。例如房子面积  (x) 和房价 (y_)。
    x = tf.placeholder(tf.float32,[None,1]); # X占位一条 Nx1维的向量
    
    ## 变量:表示我们试图寻找的能够使成本函数降到最小的「good」值的变量,例如 W 和 b。
    W = tf.Variable(tf.zeros([1,1])); # tf.zeros([1,1]):生成 第1行含1个元素的【二维】数组:[[ 0.]]
    b = tf.Variable(tf.zeros([1]));   # tf.zeros([1])  : 生成 第1行含1个元素的【一维数组】:[0.]
    
    ## 然后 TensorFlow 中的线性模型 (y = W.x + b) 就是:
    y = tf.matmul(x,W)+b;
    
    # 2.TensorFlow 中的成本函数
    ## 与将数据点的实际房价 (y_) 输入模型类似,我们创建一个占位符。
    y_ = tf.placeholder(tf.float32,[None,1])
    
    ## 成本函数的最小方差就是:
    cost = tf.reduce_sum(tf.pow(y_ - y,2)); # 各项样本点的最小方差之和作为拟合的成本函数
    
    # 3.数据
    ## 由于没有房价(y_) 和房子面积 (x) 的实际数据点,我们就生成它们
    ## 简单起见,我们将房价 (ys) 设置成永远是房子面积 (xs) 的 2 倍。
    for i in range(100):
        ## create fake data for actual data
        xs = np.array([[i]]);
        ys = np.array([[2*i+20]]);
        pass;
    
    # 4.梯度下降
    ## 有了线性模型、成本函数和数据,我们就可以开始执行梯度下降从而最小化代价函数,以获得 W、b 的「good」值。
    learning_rate = 0.001; ## 学习率 or步长 (每次进行训练时在最陡的梯度方向上所采取的「步」长)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); 
    
    # 步骤二:训练模型
    ## 训练包含以预先确定好的次数执行梯度下降,或者是直到成本函数低于某个预先确定的临界值为止。
    
    # 1.TensorFlow 的怪异
    ## 所有变量都需要在训练开始时进行初始化,否则它们可能会带有之前执行过程中的残余值。
    init = tf.initialize_all_variables();
    
    # 2.TensorFlow 会话
    ## 虽然 TensorFlow 是一个 Python 库,Python 是一种解释性的语言,但是默认情况下不把 TensorFlow 运算用作解释性能的原因,因此不执行上面的 init 。
    ## 相反 TensorFlow 是在一个会话中进行;创建一个会话 (sess) 然后使用 sess.run() 去执行。
    session = tf.Session();
    session.run(init)
    
    steps = 50; # 迭代次数过高以后,会产生过拟合现象【其计算出的值可能会是严重错误的拟合值】
    # 类似地我们在一个循环中调用 withinsess.run() 来执行上面的 train_step
    arrayX = [];
    arrayY = [];
    for i in range(steps):
        # Create fake data for y = W*x + b where W=2,b=0.2
        xs = np.array([[i]]); 
        ys = np.array([[2*i+0.2]]);
    #     xs = np.array([x_data[i]]); 
    #     ys = np.array([y_true[i]]);
        
        arrayX.extend(xs[0]);
        arrayY.extend(ys[0]);
        
        # Train
        feed = {x:xs,y_:ys};
        session.run(train_step,feed_dict=feed); # feed them into train_step
        
        # View 
        print("After %d iteration:"%i)
        print("W:%f"%session.run(W))
        print("b:%f"%session.run(b))
        pass; 
    
    # 可视化
    print("W:
    ",session.run(W));
    print("b:
    ",session.run(b));
    arrayX = np.array(arrayX);
    arrayX = arrayX.reshape((1,steps));
    
    arrayB = np.array(np.full(steps,session.run(b)));
    arrayB = arrayB.reshape(1,steps);
    arrayB = np.transpose(arrayB)
    # print("arrayB:
    ",arrayB);
    predictYs = np.dot(np.transpose(arrayX),session.run(W))+ arrayB;
    
    # print(predictYs);
    # print(arrayX)
    #print(arrayY)
    
    plt.rcParams['figure.dpi'] = 300 #分辨率
    plt.scatter(arrayX, arrayY, marker = '*',color = 'red', s = 10 ,label = 'Actual Dataset')
    plt.scatter(arrayX, predictYs, marker = 'o',color = 'green', s = 8 ,label = 'Fit Dataset')
    plt.legend(loc = 'best')    # 设置 图例所在的位置 使用推荐位置
    
    After 0 iteration:
    W:0.000000
    b:0.000400
    After 1 iteration:
    W:0.004399
    b:0.004799
    After 2 iteration:
    W:0.021145
    b:0.013172
    After 3 iteration:
    W:0.057885
    b:0.025419
    After 4 iteration:
    W:0.121430
    b:0.041305
    After 5 iteration:
    W:0.216945
    b:0.060408
    After 6 iteration:
    W:0.347000
    b:0.082084
    After 7 iteration:
    W:0.510645
    b:0.105462
    After 8 iteration:
    W:0.702795
    b:0.129480
    After 9 iteration:
    W:0.914212
    b:0.152971
    After 10 iteration:
    W:1.132310
    b:0.174781
    After 11 iteration:
    W:1.342846
    b:0.193921
    After 12 iteration:
    W:1.532252
    b:0.209704
    After 13 iteration:
    W:1.690099
    b:0.221846
    After 14 iteration:
    W:1.810968
    b:0.230480
    After 15 iteration:
    W:1.895118
    b:0.236090
    After 16 iteration:
    W:1.947663
    b:0.239374
    After 17 iteration:
    W:1.976575
    b:0.241075
    After 18 iteration:
    W:1.990276
    b:0.241836
    After 19 iteration:
    W:1.995707
    b:0.242122
    After 20 iteration:
    W:1.997456
    b:0.242209
    After 21 iteration:
    W:1.997927
    b:0.242232
    After 22 iteration:
    W:1.998075
    b:0.242238
    After 23 iteration:
    W:1.998169
    b:0.242242
    After 24 iteration:
    W:1.998251
    b:0.242246
    After 25 iteration:
    W:1.998325
    b:0.242249
    After 26 iteration:
    W:1.998393
    b:0.242251
    After 27 iteration:
    W:1.998455
    b:0.242254
    After 28 iteration:
    W:1.998512
    b:0.242256
    After 29 iteration:
    W:1.998564
    b:0.242258
    After 30 iteration:
    W:1.998613
    b:0.242259
    After 31 iteration:
    W:1.998658
    b:0.242261
    After 32 iteration:
    W:1.998701
    b:0.242262
    After 33 iteration:
    W:1.998741
    b:0.242263
    After 34 iteration:
    W:1.998778
    b:0.242264
    After 35 iteration:
    W:1.998813
    b:0.242265
    After 36 iteration:
    W:1.998846
    b:0.242266
    After 37 iteration:
    W:1.998878
    b:0.242267
    After 38 iteration:
    W:1.998907
    b:0.242268
    After 39 iteration:
    W:1.998935
    b:0.242269
    After 40 iteration:
    W:1.998960
    b:0.242269
    After 41 iteration:
    W:1.998989
    b:0.242270
    After 42 iteration:
    W:1.999004
    b:0.242270
    After 43 iteration:
    W:1.999050
    b:0.242271
    After 44 iteration:
    W:1.999007
    b:0.242270
    After 45 iteration:
    W:1.999222
    b:0.242275
    After 46 iteration:
    W:1.998624
    b:0.242262
    After 47 iteration:
    W:2.000731
    b:0.242307
    After 48 iteration:
    W:1.993301
    b:0.242152
    After 49 iteration:
    W:2.021338
    b:0.242724
    
    W:
     [[ 2.02133822]]
    b:
     [ 0.24272442]
    

    推荐文献

  • 相关阅读:
    aaa
    记一次Vue实战总结
    Data too long for column 'xxx' at row 1MySql.Data.MySqlClient.MySqlPacket ReadPacket() 报错解决
    uni-app 监听返回按钮
    微信H5分享外部链接,缩略图不显示
    uni-app 动态控制下拉刷新
    vueX 的使用
    uni-app H5 腾讯地图无法导航
    uni-app支付功能
    hooks 与 animejs
  • 原文地址:https://www.cnblogs.com/johnnyzen/p/10856815.html
Copyright © 2011-2022 走看看