zoukankan      html  css  js  c++  java
  • tensorflow笔记:模型的保存与训练过程可视化

    tensorflow笔记系列: 
    (一) tensorflow笔记:流程,概念和简单代码注释 
    (二) tensorflow笔记:多层CNN代码分析 
    (三) tensorflow笔记:多层LSTM代码分析 
    (四) tensorflow笔记:常用函数说明 
    (五) tensorflow笔记:模型的保存与训练过程可视化 
    (六)tensorflow笔记:使用tf来实现word2vec


    保存与读取模型

    在使用tf来训练模型的时候,难免会出现中断的情况。这时候自然就希望能够将辛辛苦苦得到的中间参数保留下来,不然下次又要重新开始。好在tf官方提供了保存和读取模型的方法。

    保存模型的方法:

    # 之前是各种构建模型graph的操作(矩阵相乘,sigmoid等等....)
    
    saver = tf.train.Saver() # 生成saver
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer()) # 先对模型初始化
    
        # 然后将数据丢入模型进行训练blablabla
    
        # 训练完以后,使用saver.save 来保存
        saver.save(sess, "save_path/file_name") #file_name如果不存在的话,会自动创建

    将模型保存好以后,载入也比较方便,如下所示:

    saver = tf.train.Saver()
    
    with tf.Session() as sess:
        #参数可以进行初始化,也可不进行初始化。即使初始化了,初始化的值也会被restore的值给覆盖
        sess.run(tf.global_variables_initializer())     
        saver.restore(sess, "save_path/file_name") #会将已经保存的变量值resotre到 变量中。

    简单的说,就是通过saver.save来保存模型,通过saver.restore来加载模型


    使用tensorboard来使训练过程可视化

    tensorflow还提供了一个可视化工具,叫tensorboard.启动以后,可以通过网页来观察模型的结构和训练过程中各个参数的变化。如下图所示

    选区_059.png-12.7kB

    关于如何合理清楚的显示网络结构,我目前还不太搞得清楚,而且目前看来也不是太重要;但是要将训练的过程可视化还是比较方便的。简单的说,流程如下所示:

    • 使用tf.scalar_summary来收集想要显示的变量
    • 定义一个summury op, 用来汇总多个变量
    • 得到一个summy writer,指定写入路径
    • 通过summary_str = sess.run()
    # 1. 由之前的各种运算得到此批数据的loss
    loss = ..... 
    
    # 2.使用tf.scalar_summary来收集想要显示的变量,命名为loss
    tf.scalar_summary('loss',loss)  
    
    # 3.定义一个summury op, 用来汇总由scalar_summary记录的所有变量
    merged_summary_op = tf.merge_all_summaries()
    
    # 4.生成一个summary writer对象,需要指定写入路径,例如我这边就是/tmp/logdir
    summary_writer = tf.train.SummaryWriter('/tmp/logdir', sess.graph)
    
    # 开始训练,分批喂数据
    for(i in range(batch_num)):
        # 5.使用sess.run来得到merged_summary_op的返回值
        summary_str = sess.run(merged_summary_op)
    
        # 6.使用summary writer将运行中的loss值写入
        summary_writer.add_summary(summary_str,i)

    接下来,程序开始运行以后,跑到shell里运行

    $ tensorboard --logdir /tmp/logdir

    开始运行tensorboard.接下来打开浏览器,进入127.0.0.1:6006 就能够看到loss值在训练中的变化值了。

  • 相关阅读:
    使用 Log4Net 记录日志
    NuGet安装和使用
    .NET Framework 4 与 .NET Framework 4 Client Profile
    “init terminating in do_boot” Windows10 Rabbit MQ fails to start
    Ubuntu / Win7 安装db2 v10.5
    Win7下的内置FTP组件的设置详解
    c/s模式 (C#)下Ftp的多文件上传及其上传进度
    C#路径/文件/目录/I/O常见操作汇总
    C# 遍历指定目录下的所有文件及文件夹
    Mongodb主从复制 及 副本集+分片集群梳理
  • 原文地址:https://www.cnblogs.com/wuzhitj/p/6298008.html
Copyright © 2011-2022 走看看