用Tensorflow搭建神经网络的一般步骤如下:
① 导入模块
② 创建模型变量和占位符
③ 建立模型
④ 定义loss函数
⑤ 定义优化器(optimizer), 使 loss 达到最小
⑥ 引入激活函数, 即添加非线性因素 (线性回归问题跳过此步骤)
⑦ 训练模型
⑧ 检验模型
⑨ 使用模型预测数据
⑩ 保存模型
⑪ 使用Tensorboard的可视化功能
下面以一个简单的线性回归问题为例:
首先是训练模型的代码: train_model.py
1 # ① 导入模块 2 import tensorflow as tf 3 4 # ② 创建模型的变量和占位符 5 W = tf.Variable([.3], dtype=tf.float32) 6 b = tf.Variable([-.3], dtype=tf.float32) 7 x = tf.placeholder(tf.float32, name="input_x") 8 y = tf.placeholder(tf.float32, name="input_y") 9 10 # ③建立模型 11 linear_model = W*x + b 12 # 如果是矩阵相乘,可以写成: 13 # linear_model = tf.matmul(x, W)+b # matmul表示矩阵相乘 14 15 # ④ 定义loss函数 16 loss = tf.reduce_sum(tf.square(linear_model - y)) 17 18 # ⑤ 定义优化器(optimizer), 使 loss 达到最小 19 learning_rate=0.01 20 optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate) 21 train = optimizer.minimize(loss) 22 23 # ⑥ 引入激活函数, 即添加非线性因素。(线性回归问题跳过此步骤) 24 25 # ⑦ 训练模型 26 # 假设模型是y=2x+1 27 x_train = [1, 2, 3, 4] 28 y_train = [3, 5, 7, 9] 29 30 init = tf.global_variables_initializer() # 添加用于初始化变量的节点 31 sess = tf.Session() 32 sess.run(init) # 运行初始化操作 33 for step in range(1000): 34 sess.run(train, {x: x_train, y: y_train}) 35 36 ''' 37 第⑦步和第⑩步可以合并为: 38 for step in xrange(1000000): 39 sess.run(train, {x: x_train, y: y_train}) 40 if step % 1000 == 0: 41 saver.save(sess, 'my-model', global_step=step) 42 ''' 43 44 # ⑧ 检验模型 45 curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x: x_train, y: y_train}) 46 print("W: %s b: %s loss: %s"%(curr_W, curr_b, curr_loss)) 47 ''' 48 W: [ 2.00000167] b: [ 0.99999553] loss: 1.29603e-11 49 ''' 50 51 # ⑨ 使用模型预测数据 52 x_predict = [-1, 0, 1, 2] 53 predicted_values=sess.run(linear_model, feed_dict={x:x_predict}) 54 # 注意这么一种写法: predicted_values = [(W*x + b).eval(session=sess) for x in x_predict] 55 print("result:", predicted_values) 56 ''' 57 result: [-1.0000062 0.99999553 2.99999714 4.99999905] 58 ''' 59 60 # ⑩ 保存模型 61 tf.add_to_collection("predict_network", linear_model) 62 saver = tf.train.Saver() 63 saver_path=saver.save(sess, "save/model.ckpt") 64 65 # ⑪ 使用Tensorboard的可视化功能 66 # 定义保存日志的路径 67 path = "log" # 也可写成: path = "./log" 68 writer=tf.summary.FileWriter(path, sess.graph) 69 70 sess.close()
然后是载入模型的代码: restore_model.py
1 import tensorflow as tf 2 3 with tf.Session() as sess: 4 new_saver=tf.train.import_meta_graph("save/model.ckpt.meta") 5 new_saver.restore(sess,"save/model.ckpt") 6 # print(tf.get_collection("predict_network")) 7 restored_y=tf.get_collection("predict_network")[0] # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可 8 9 graph=tf.get_default_graph() 10 restored_x=graph.get_operation_by_name("input_x").outputs[0] 11 12 predict_data = [-2, 3, 4] 13 predicted_result = sess.run(restored_y, feed_dict={restored_x:predict_data}) 14 15 print("result:", predicted_result) # result: [-3.00000787 7.00000048 9.00000191]