zoukankan      html  css  js  c++  java
  • TensorFlow保存和载入模型

    首先定义一个tf.train.Saver类:

    saver = tf.train.Saver(max_to_keep=1)

    其中,max_to_keep参数设定只保存最后一个参数,默认值是5,即保存最后5个模型,如果设置成0,训练过程中的所有模型都会被保存。

    模型训练好以后,保存模型:

    saver.save(sess, ckpt_dir + "/nn_model.ckpt", global_step=1)

    其中,sess是Session,ckpt_dir + "/nn_model.ckpt"是保存的路径和名称,global_step是模型名称的后缀名,由于我们只保存最后一个模型,所以可以设置为1,如果每一个模型都想保存,可以设置成训练的epoch。

    载入模型比较简单:

    saver.restore(sess, model_file)

    其中,sess是Session,model_file是模型的路径和名称。

  • 相关阅读:
    python数字
    Python数据类型
    Python表达式与运算符
    正则表达式
    计划任务
    nfs服务
    nginx反向代理+负载均衡
    samba安装测试
    自定义centos7 yum仓库
    token过期时间
  • 原文地址:https://www.cnblogs.com/mstk/p/9395589.html
Copyright © 2011-2022 走看看