zoukankan      html  css  js  c++  java
  • TensorFlow——TensorBoard可视化

    TensorFlow提供了一个可视化工具TensorBoard,它能够将训练过程中的各种绘制数据进行展示出来,包括标量,图片,音频,计算图,数据分布,直方图等,通过网页来观察模型的结构和训练过程中各个参数的变化。

    Tensorboard通过一个日志展示系统进行数据可视化,在session运行图的时候,将各类的数据汇总并输出到日志文件中。然后启动Tensorboard服务,Tensorboard读取日志文件,并开启6006端口提供web服务。让用户可以在浏览器中查看数据。

    相关的API函数如下;

    tf.summary.scalar() :标量数据汇总,输出protobuf

    tf.summary.histogram() :记录变量var的直方图,输出到直方图汇总的protobuf

    tf.summary.image() :图像数据汇总,输出protobuf

    tf.summary.merge() :合并所有的汇总日志

    tf.summary.FileWriter() :创建SummaryWriter

    tf.summary.FileWriter().add_summary()

    tf.summary.FileWriter().add_session_log()

    tf.summary.FileWriter().add_event()

    tf.summary.FileWriter().add_graph() : 将protobuf写入文件的类

    代码如下:

    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    
    train_x = np.linspace(-5, 3, 50)
    train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5
    
    # plt.plot(train_x, train_y, 'r.')
    # plt.grid(True)
    # plt.show()
    
    X = tf.placeholder(dtype=tf.float32)
    Y = tf.placeholder(dtype=tf.float32)
    
    w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
    b = tf.Variable(tf.random.truncated_normal([1]), name='bias')
    
    z = tf.multiply(X, w) + b
    
    tf.summary.histogram('z', z)
    
    cost = tf.reduce_mean(tf.square(Y - z))
    
    tf.summary.scalar('loss', cost)
    
    learning_rate = 0.01
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
    
    init = tf.global_variables_initializer()
    
    training_epochs = 20
    display_step = 2
    
    
    with tf.Session() as sess:
        sess.run(init)
        loss_list = []
    
        merged_summary_op = tf.summary.merge_all()  # 合并所有的summary
        summary_wirter = tf.summary.FileWriter('log/linear', sess.graph)
    
        for epoch in range(training_epochs):
            for (x, y) in zip(train_x, train_y):
                sess.run(optimizer,feed_dict={X:x, Y:y})
    
            if epoch % display_step == 0:
                loss = sess.run(cost, feed_dict={X:x, Y:y})
                loss_list.append(loss)
                print('Iter: ', epoch, ' Loss: ', loss)
            summary_str = sess.run(merged_summary_op, feed_dict={X:train_x, Y:train_y})
            summary_wirter.add_summary(summary_str, epoch)
    
        w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
        print(" Finished ")
        print("W: ", w_, " b: ", b_, " loss: ", loss)
        plt.plot(train_x, train_x*w_ + b_, 'g-', train_x, train_y, 'r.')
        plt.grid(True)
        plt.show()

    上述的可视化步骤主要是

      1.将需要可视化的变量加入summary,做好可视化的定义操作

      2.merged_summary_op = tf.summary.merge_all() # 合并所有的summary

      3.创建summary_wirter对象,并将图写入文件

      4.获取可视化的数据,通过summary_writer对象将数据进行写入

     

    在程序运行完,将会在指定好的路径中生成日志文件,

    通过命令行工具切换到该目录,

     

    执行命令:tensorboard --logdir=生成的日志文件的路径

    打开浏览器进行查看,

     

    定义的图的结构:

  • 相关阅读:
    java.lang.NoSuchMethodError: org.springframework.web.context.request.ServletRequestAttributes.<init>
    eclipse web项目实际工程路径对应
    java中专业术语详解
    Maven详解
    工作常用
    html页面布局
    jQuery易混淆概念的区别
    Jquery Datagrid
    Jquery EasyUI 动态添加标签页(Tabs)
    sql语句的写法
  • 原文地址:https://www.cnblogs.com/baby-lily/p/10931302.html
Copyright © 2011-2022 走看看