zoukankan      html  css  js  c++  java
  • 用Tensorflow搭建神经网络的一般步骤

    用Tensorflow搭建神经网络的一般步骤如下:

    导入模块

    创建模型变量和占位符

    建立模型

    定义loss函数

    定义优化器(optimizer), 使 loss 达到最小

    引入激活函数, 即添加非线性因素 (线性回归问题跳过此步骤)

    训练模型

    检验模型

    使用模型预测数据

    保存模型

    使用Tensorboard的可视化功能

    下面以一个简单的线性回归问题为例:

    首先是训练模型的代码: train_model.py

     1 # ① 导入模块
     2 import tensorflow as tf
     3 
     4 # ② 创建模型的变量和占位符
     5 W = tf.Variable([.3], dtype=tf.float32)
     6 b = tf.Variable([-.3], dtype=tf.float32)
     7 x = tf.placeholder(tf.float32, name="input_x")
     8 y = tf.placeholder(tf.float32, name="input_y")
     9 
    10 # ③建立模型
    11 linear_model = W*x + b
    12 # 如果是矩阵相乘,可以写成:
    13 # linear_model = tf.matmul(x, W)+b  # matmul表示矩阵相乘
    14 
    15 # ④ 定义loss函数
    16 loss = tf.reduce_sum(tf.square(linear_model - y))
    17 
    18 # ⑤ 定义优化器(optimizer), 使 loss 达到最小
    19 learning_rate=0.01
    20 optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)
    21 train = optimizer.minimize(loss)
    22 
    23 # ⑥ 引入激活函数, 即添加非线性因素。(线性回归问题跳过此步骤)
    24 
    25 # ⑦ 训练模型
    26 # 假设模型是y=2x+1
    27 x_train = [1, 2, 3, 4]
    28 y_train = [3, 5, 7, 9]
    29 
    30 init = tf.global_variables_initializer() # 添加用于初始化变量的节点
    31 sess = tf.Session()
    32 sess.run(init) # 运行初始化操作
    33 for step in range(1000):
    34    sess.run(train, {x: x_train, y: y_train})
    35 
    36 '''
    37 第⑦步和第⑩步可以合并为:
    38 for step in xrange(1000000):
    39     sess.run(train, {x: x_train, y: y_train})
    40     if step % 1000 == 0:
    41         saver.save(sess, 'my-model', global_step=step)
    42 '''
    43 
    44 # ⑧ 检验模型
    45 curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x: x_train, y: y_train})
    46 print("W: %s b: %s loss: %s"%(curr_W, curr_b, curr_loss))
    47 '''
    48 W: [ 2.00000167] b: [ 0.99999553] loss: 1.29603e-11
    49 '''
    50 
    51 # ⑨ 使用模型预测数据
    52 x_predict = [-1, 0, 1, 2]
    53 predicted_values=sess.run(linear_model, feed_dict={x:x_predict})
    54 # 注意这么一种写法: predicted_values = [(W*x + b).eval(session=sess) for x in x_predict]
    55 print("result:", predicted_values)
    56 '''
    57 result: [-1.0000062   0.99999553  2.99999714  4.99999905]
    58 '''
    59 
    60 # ⑩ 保存模型
    61 tf.add_to_collection("predict_network", linear_model)
    62 saver = tf.train.Saver()
    63 saver_path=saver.save(sess, "save/model.ckpt")
    64 
    65 # ⑪ 使用Tensorboard的可视化功能
    66 # 定义保存日志的路径
    67 path = "log"  # 也可写成: path = "./log"
    68 writer=tf.summary.FileWriter(path, sess.graph)
    69 
    70 sess.close()

    然后是载入模型的代码: restore_model.py

     1 import tensorflow as tf
     2 
     3 with tf.Session() as sess:
     4     new_saver=tf.train.import_meta_graph("save/model.ckpt.meta")
     5     new_saver.restore(sess,"save/model.ckpt")
     6     # print(tf.get_collection("predict_network"))
     7     restored_y=tf.get_collection("predict_network")[0]  # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可
     8 
     9     graph=tf.get_default_graph()
    10     restored_x=graph.get_operation_by_name("input_x").outputs[0]
    11 
    12     predict_data = [-2, 3, 4]
    13     predicted_result = sess.run(restored_y, feed_dict={restored_x:predict_data})
    14 
    15     print("result:", predicted_result)  # result: [-3.00000787  7.00000048  9.00000191]
  • 相关阅读:
    支付宝H5、APP支付服务端的区别(php)
    微信小程序快速转成百度小程序的方法
    pm2命令管理启动的nodejs项目进程
    CentOS7 宝塔搭配git 实时更新项目源码
    CentOS7 搭建GIT环境
    Json数据交互
    HTML标签大全
    Java集合
    JavaSE基础知识
    idea2019版本及以下全家桶永久破解
  • 原文地址:https://www.cnblogs.com/ArrozZhu/p/8407075.html
Copyright © 2011-2022 走看看