简单例子介绍Tensorflow实现机器学习的思路,重点步骤:
- 生成人工数据集及其可视化
- 构建线性模型
- 定义损失函数
- 定义优化器、最小损失函数
- 训练结果的可视化
- 利用学习到的模型进行预测
1 import tensorflow as tf 2 import numpy as np 3 import matplotlib.pyplot as plt 4 5 np.random.seed(5) 6 # 采用np生成等差数列,范围在-1~1之间生成100个点 7 x_data=np.linspace(-1,1,100) 8 # y=2x+1+噪声,其中噪声的维度和x_data一致 9 y_data=2*x_data+1.0+np.random.randn(*x_data.shape)*0.4 10 11 # 画图 y=2x+1 12 plt.scatter(x_data,y_data) 13 plt.plot(x_data,2*x_data+1,color='red',linewidth=3) 14 # plt.show() 15 16 # 定义训练数据的占位符,x是特征值,y是标签值 17 x=tf.placeholder("float",name="x") 18 y=tf.placeholder("float",name="y") 19 20 # 定义模型函数 21 def model(x,w,b): 22 return tf.multiply(x,w)+b 23 24 # tf.Variable的作用时保存和更新函数 25 w=tf.Variable(1.0) #斜率 26 b=tf.Variable(0.0) #截距 27 pred=model(x,w,b) #预测值 28 29 # 迭代次数和学习率 30 train_epochs=10 31 learning_rate=0.05 32 33 # 损失函数用来描述预测值与真实值之间的误差,从而指导模型收敛方向,均方差MSE 34 loss=tf.reduce_mean(tf.square(y-pred)) 35 # 定义优化器optimizer,初始化一个GradientDescentOptimizer,设置学习率和优化目标,最小值损失 36 optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss) 37 38 sess=tf.Session() 39 init=tf.global_variables_initializer() 40 sess.run(init) 41 42 # 模型训练阶段,设置迭代轮次,每次通过样本逐个输入模型,进行梯度下降优化操作 43 # 每次迭代,绘制出模型曲线 44 45 # 开始训练,轮次为epoch,采用SGD随机梯度下降优化方法 46 step=0 #记录训练步数 47 loss_list=[] #用于保存loss值的列表 48 49 for epoch in range(train_epochs): 50 for xs,ys in zip(x_data,y_data): 51 _,losss=sess.run([optimizer,loss],feed_dict={x:xs,y:ys}) 52 # 显示损失值loss 53 # display_step:控制报告的粒度 54 # 例如,如果display_step设为2,则将每训练2个样本输出一次损失值 55 # 与超参数不同,修改display_step 不会更改模型所学习的规律 56 loss_list.append(losss) 57 step=step+1 58 display_step=10 59 if step%display_step==0: 60 print("Train Epoch:",'%02d'%(epoch+1),"Step:%03d"%(step),"loss=","{:.9f}".format(losss)) 61 b0temp=b.eval(session=sess) 62 w0temp=w.eval(session=sess) 63 plt.plot(x_data,w0temp*x_data+b0temp) 64 plt.plot(loss_list,'r+') 65 plt.show() 66 67 # 训练完成后,打印查看参数 68 print("w:",sess.run(w)) 69 print("b",sess.run(b)) 70 71 plt.scatter(x_data,y_data,label='Original data') 72 plt.plot(x_data,x_data*sess.run(w)+sess.run(b),label='Fitted line',color='r',linewidth=3) 73 plt.legend(loc=2) #通过参数loc指定图例位置 74 # plt.show() 75 76 x_test=3.21 77 predict=sess.run(pred,feed_dict={x:x_test}) 78 print("预测值:%f"%predict) 79 80 target=2*x_test+1.0 81 print("目标值%f"%target)