zoukankan      html  css  js  c++  java
  • 单变量线性回归问题(TensorFlow实战)

    简单例子介绍Tensorflow实现机器学习的思路,重点步骤:

    1. 生成人工数据集及其可视化
    2. 构建线性模型
    3. 定义损失函数
    4. 定义优化器、最小损失函数
    5. 训练结果的可视化
    6. 利用学习到的模型进行预测
     1 import tensorflow as tf
     2 import numpy as np
     3 import matplotlib.pyplot as plt
     4 
     5 np.random.seed(5)
     6 # 采用np生成等差数列,范围在-1~1之间生成100个点
     7 x_data=np.linspace(-1,1,100)
     8 # y=2x+1+噪声,其中噪声的维度和x_data一致
     9 y_data=2*x_data+1.0+np.random.randn(*x_data.shape)*0.4
    10 
    11 # 画图 y=2x+1
    12 plt.scatter(x_data,y_data)
    13 plt.plot(x_data,2*x_data+1,color='red',linewidth=3)
    14 # plt.show()
    15 
    16 # 定义训练数据的占位符,x是特征值,y是标签值
    17 x=tf.placeholder("float",name="x")
    18 y=tf.placeholder("float",name="y")
    19 
    20 # 定义模型函数
    21 def model(x,w,b):
    22     return tf.multiply(x,w)+b
    23 
    24 # tf.Variable的作用时保存和更新函数
    25 w=tf.Variable(1.0)      #斜率
    26 b=tf.Variable(0.0)      #截距
    27 pred=model(x,w,b)       #预测值
    28 
    29 # 迭代次数和学习率
    30 train_epochs=10
    31 learning_rate=0.05
    32 
    33 # 损失函数用来描述预测值与真实值之间的误差,从而指导模型收敛方向,均方差MSE
    34 loss=tf.reduce_mean(tf.square(y-pred))
    35 # 定义优化器optimizer,初始化一个GradientDescentOptimizer,设置学习率和优化目标,最小值损失
    36 optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
    37 
    38 sess=tf.Session()
    39 init=tf.global_variables_initializer()
    40 sess.run(init)
    41 
    42 # 模型训练阶段,设置迭代轮次,每次通过样本逐个输入模型,进行梯度下降优化操作
    43 # 每次迭代,绘制出模型曲线
    44 
    45 # 开始训练,轮次为epoch,采用SGD随机梯度下降优化方法
    46 step=0  #记录训练步数
    47 loss_list=[]    #用于保存loss值的列表
    48 
    49 for epoch in range(train_epochs):
    50     for xs,ys in zip(x_data,y_data):
    51         _,losss=sess.run([optimizer,loss],feed_dict={x:xs,y:ys})
    52     # 显示损失值loss
    53     # display_step:控制报告的粒度
    54     # 例如,如果display_step设为2,则将每训练2个样本输出一次损失值
    55     # 与超参数不同,修改display_step 不会更改模型所学习的规律
    56         loss_list.append(losss)
    57         step=step+1
    58         display_step=10
    59         if step%display_step==0:
    60             print("Train Epoch:",'%02d'%(epoch+1),"Step:%03d"%(step),"loss=","{:.9f}".format(losss))
    61     b0temp=b.eval(session=sess)
    62     w0temp=w.eval(session=sess)
    63     plt.plot(x_data,w0temp*x_data+b0temp)
    64 plt.plot(loss_list,'r+')
    65 plt.show()
    66 
    67 # 训练完成后,打印查看参数
    68 print("w:",sess.run(w))
    69 print("b",sess.run(b))
    70 
    71 plt.scatter(x_data,y_data,label='Original data')
    72 plt.plot(x_data,x_data*sess.run(w)+sess.run(b),label='Fitted line',color='r',linewidth=3)
    73 plt.legend(loc=2)   #通过参数loc指定图例位置
    74 # plt.show()
    75 
    76 x_test=3.21
    77 predict=sess.run(pred,feed_dict={x:x_test})
    78 print("预测值:%f"%predict)
    79 
    80 target=2*x_test+1.0
    81 print("目标值%f"%target)
  • 相关阅读:
    python批量裁剪图片
    Theano 报错:No suitable SharedVariable constructor could be found. Are you sure all kwargs are supported? We do not support the parameter dtype or type
    清华镜像连接
    ubuntu16.04查看占用GPU的程序
    pycharm报错:ImportError: libcusolver.so.8.0: cannot open shared object file: No such file or directory
    PyMysql的基本操作
    关于爬虫解析页面时的一些有意思的坑
    关于爬虫解析页面时的一些有意思的坑
    python 的一些高级函数
    python 的一些高级函数
  • 原文地址:https://www.cnblogs.com/hly97/p/12815086.html
Copyright © 2011-2022 走看看