zoukankan      html  css  js  c++  java
  • tensorflow开发基本步骤

    Tensorflow开发的基本步骤:

    • 定义Tensorflow输入节点
    1. 通过占位符定义:
      X = tf.placeholder("float")

      2.通过字典类型定义:

    inputdict = {
        'x': tf.placeholder("float"),
        'y': tf.placeholder("float")
    }

      3. 直接定义输入节点:

    train_x = np.float32(np.linspace(-1,1,100))
    • 定义“学习参数”的变量
    • 定义“运算”
    • 优化函数,优化目标
    • 初始化所有变量
    • 迭代更新参数到最优解
    • 测试模型
    • 使用模型

    2、模型保存与载入

    • 模型保存:
    saver = tf.train.Saver()  #生成saver
    saverdir = "log/"
    with tf.Session() as sess:
        sess.run(init)
        print("Finished")
        saver.save(sess,saverdir+"linermodel.cpkt")
    • 模型载入:
    with tf.Session() as sess2:
        sess2.run(tf.global_variables_initializer())
        saver.restore(sess2,saverdir+"linermodel.cpkt")
        print("x=0.2,z=",sess2.run(z,feed_dict={X:0.2}))

    检查点(Checkpoint):Tensorflow训练模型时难免会出现中断的情况,希望能够将辛苦得到的中间参数保留下来,在训练中保存模型,习惯上称之为保存检查点。

     saver = tf.train.Saver(max_to_keep=1)  #生成saver
     saver.restore(sess2,saverdir+"linermodel.cpkt-"+str(load_epoch))
  • 相关阅读:
    poj 2312 Battle City
    poj 2002 Squares
    poj 3641 Pseudoprime numbers
    poj 3580 SuperMemo
    poj 3281 Dining
    poj 3259 Wormholes
    poj 3080 Blue Jeans
    poj 3070 Fibonacci
    poj 2887 Big String
    poj 2631 Roads in the North
  • 原文地址:https://www.cnblogs.com/wyx501/p/10541524.html
Copyright © 2011-2022 走看看