zoukankan      html  css  js  c++  java
  • 基于tensorflow的简单线性回归模型

    #!/usr/local/bin/python3

    ##ljj [1]

    ##linear regression model

    import tensorflow as tf

    import matplotlib.pyplot as plt

    #训练样本,随手写的

    x_ = [11,14,22,29,32,40,44,55,59,60,69,77]

    y_res = [123,135,155,167,177,189,200,240,250,255,277,298]

    #初始化定义w和b,都为1,这里折腾了一会,主要因为tf.ones的参数

    w = tf.Variable(tf.ones([1]),dtype="float32")

    b = tf.Variable(tf.ones([1]),dtype="float32")

    y = tf.placeholder(tf.float32)

    x = tf.placeholder(tf.float32)

    with tf.Session() as sess:

    #定义线性模型

        y_predict = w*x+b

    #平方误差作为损失函数

        loss = tf.reduce_mean(tf.square(y-y_predict))

    #配置训练优化器和学习速率

        train = tf.train.AdamOptimizer(0.03).minimize(loss)

      

        sess.run(tf.global_variables_initializer())

       

     for j in range(1000): 

         for i in range(len(x_)):

             # train.run(feed_dict={x:x_[i], y:y_res[i]})

        #feed训练,并输出w和b

              w_,b_,_= sess.run([w,b,train],feed_dict={x:x_[i], y:y_res[i]})

         print(w_,b_)


    print('final result : ')
    print(w_,b_)

    plt.plot(x_,y_res,'.')

    plt.plot(x_,x_*w_+b_,'-')

    plt.show()

    主机环境:MacbookPro,tensoflow版本1.4,pyhton3.5

    输出结果:

    final result : 

    [ 2.65540743] [ 91.92604065]

    -------以上输出分别是拟合出的Weight,Bias值。不同版本的tensorlfow,拟合的线可能会略有差异,稍微调调参就可以拟合的不错。

  • 相关阅读:
    ThinkPHP函数详解:C方法
    ThinkPHP函数详解:A方法
    php中的中文字符串长度计算以及截取
    JQ $("#form1 :input" ).length 与 $("#form1input").length有什么区别?
    php中的isset和empty的区别与认识
    谈谈ACM带来的一些东西
    HDU 4374--F(x)
    奖学金
    数字排序
    查找数字
  • 原文地址:https://www.cnblogs.com/lingjiajun/p/8964889.html
Copyright © 2011-2022 走看看