zoukankan      html  css  js  c++  java
  • tensorflow开发基本步骤

    Tensorflow开发的基本步骤:

    • 定义Tensorflow输入节点
    1. 通过占位符定义:
      X = tf.placeholder("float")

      2.通过字典类型定义:

    inputdict = {
        'x': tf.placeholder("float"),
        'y': tf.placeholder("float")
    }

      3. 直接定义输入节点:

    train_x = np.float32(np.linspace(-1,1,100))
    • 定义“学习参数”的变量
    • 定义“运算”
    • 优化函数,优化目标
    • 初始化所有变量
    • 迭代更新参数到最优解
    • 测试模型
    • 使用模型

    2、模型保存与载入

    • 模型保存:
    saver = tf.train.Saver()  #生成saver
    saverdir = "log/"
    with tf.Session() as sess:
        sess.run(init)
        print("Finished")
        saver.save(sess,saverdir+"linermodel.cpkt")
    • 模型载入:
    with tf.Session() as sess2:
        sess2.run(tf.global_variables_initializer())
        saver.restore(sess2,saverdir+"linermodel.cpkt")
        print("x=0.2,z=",sess2.run(z,feed_dict={X:0.2}))

    检查点(Checkpoint):Tensorflow训练模型时难免会出现中断的情况,希望能够将辛苦得到的中间参数保留下来,在训练中保存模型,习惯上称之为保存检查点。

     saver = tf.train.Saver(max_to_keep=1)  #生成saver
     saver.restore(sess2,saverdir+"linermodel.cpkt-"+str(load_epoch))
  • 相关阅读:
    C++中类模板的概念和意义
    欢迎访问新博客aiyoupass.com
    P2327
    P2885
    P1968
    Link-Cut-Tree
    树的重心
    点分治笔记
    SPOJ 375
    树链剖分
  • 原文地址:https://www.cnblogs.com/wyx501/p/10541524.html
Copyright © 2011-2022 走看看