zoukankan      html  css  js  c++  java
  • tensorflow笔记(三)之 tensorboard的使用

    tensorflow笔记(三)之 tensorboard的使用

    版权声明:本文为博主原创文章,转载请指明转载地址

    http://www.cnblogs.com/fydeblog/p/7429344.html

    前言

    这篇博客将介绍tensorflow当中一个非常有用的可视化工具tensorboard的使用,它将对我们分析训练效果,理解训练框架和优化算法有很大的帮助

    还记得我的第一篇tensorflow博客上的的例子吗?这篇博客会以第一篇tensorflow博客的tensorboard图为例进行展开。

    我会把这篇博客的相关代码(代码也会贴在博客上,可以直接copy生成py文件用)和notebook放在文后的百度云链接上,欢迎下载!

    1. 实践1--矩阵相乘

    相应的代码

     1 import tensorflow as tf
     2 
     3 with tf.name_scope('graph') as scope:
     4      matrix1 = tf.constant([[3., 3.]],name ='matrix1')  #1 row by 2 column
     5      matrix2 = tf.constant([[2.],[2.]],name ='matrix2') # 2 row by 1 column
     6      product = tf.matmul(matrix1, matrix2,name='product')
     7   
     8 sess = tf.Session()
     9 
    10 writer = tf.summary.FileWriter("logs/", sess.graph)
    11 
    12 init = tf.global_variables_initializer()
    13 
    14 sess.run(init)

    这里相对于第一篇tensorflow多了一点东西,tf.name_scope函数是作用域名,上述代码斯即在graph作用域op下,又有三个op(分别是matrix1,matrix2,product),用tf函数内部的name参数命名,这样会在tensorboard中显示,具体图像还请看下面。

    很重要运行上面的代码,查询当前目录,就可以找到一个新生成的文件,已命名为logs,我们需在终端上运行tensorboard,生成本地链接,具体看我截图,当然你也可以将上面的代码直接生成一个py文档在终端运行,也会在终端当前目录生成一个logs文件,然后运行tensorboard --logdir logs指令,就可以生成一个链接,复制那个链接,在google浏览器(我试过火狐也行)粘贴显示,对于tensorboard 中显示的网址打不开的朋友们, 请使用 http://localhost:6006 (如果这个没有成功,我之前没有安装tensorboard,也出现链接,但那个链接点开什么都没有,所以还有一种可能就是你没有安装tensorboard,使用pip install tensorboard安装tensorboard,python3用pip3 install tensorboard)

    具体运行过程如下(中间的警告请忽略,我把上面的代码命名为1.py运行的)

    可以看到最后一行出现了链接,复制那个链接,推荐用google浏览器打开(火狐我试过也行),也可以直接打开链接http://localhost:6006,这样更方便!

    如果出现下图,则证明打开成功

    2. 实践2---线性拟合(一)

    上面那一个是小试牛刀,比较简单,没有任何训练过程,下面将第一篇tensorflow笔记中的第二个例子来画出它的流动图(哦,对了,之所有说是流动图,这是由于tensorflow的名字就是张量在图形中流动的意思)

    代码如下:(命名文件2.py)

     1 import tensorflow as tf
     2 import numpy as np
     3 
     4 ## prepare the original data
     5 with tf.name_scope('data'):
     6      x_data = np.random.rand(100).astype(np.float32)
     7      y_data = 0.3*x_data+0.1
     8 ##creat parameters
     9 with tf.name_scope('parameters'):
    10      weight = tf.Variable(tf.random_uniform([1],-1.0,1.0))
    11      bias = tf.Variable(tf.zeros([1]))
    12 ##get y_prediction
    13 with tf.name_scope('y_prediction'):
    14      y_prediction = weight*x_data+bias
    15 ##compute the loss
    16 with tf.name_scope('loss'):
    17      loss = tf.reduce_mean(tf.square(y_data-y_prediction))
    18 ##creat optimizer
    19 optimizer = tf.train.GradientDescentOptimizer(0.5)
    20 #creat train ,minimize the loss 
    21 with tf.name_scope('train'):
    22      train = optimizer.minimize(loss)
    23 #creat init
    24 with tf.name_scope('init'): 
    25      init = tf.global_variables_initializer()
    26 ##creat a Session 
    27 sess = tf.Session()
    28 ##initialize
    29 writer = tf.summary.FileWriter("logs/", sess.graph)
    30 sess.run(init)
    31 ## Loop
    32 for step  in  range(101):
    33     sess.run(train)
    34     if step %10==0 :
    35         print step ,'weight:',sess.run(weight),'bias:',sess.run(bias)

    运行这个程序会打印一些东西,看过第一篇tensorflow笔记的人应该知道

    具体输出如下:

    具体的运行过程如下图,跟第一个差不多

    打开链接后,会出现下图

    这个就是上面代码的流动图,先初始化参数,算出预测,计算损失,然后训练,更新相应的参数。

    当然这个图还可以进行展开,里面有更详细的流动(截图无法全面,还请自己运行出看看哦)

    Parameters部分

    y_prediction部分和init部分

    loss部分

    还有最后的train部分

    具体东西还请自己展开看看,不难理解

    2. 实践2---线性拟合(二)

    我们在对上面的代码进行再修改修改,试试tensorboard的其他功能,例如scalars,distributions,histograms,它们对我们分析学习算法的性能有很大帮助。

    代码如下:(命名文件3.py)

     1 import tensorflow as tf
     2 import numpy as np
     3 
     4 ## prepare the original data
     5 with tf.name_scope('data'):
     6      x_data = np.random.rand(100).astype(np.float32)
     7      y_data = 0.3*x_data+0.1
     8 ##creat parameters
     9 with tf.name_scope('parameters'):
    10      with tf.name_scope('weights'):
    11             weight = tf.Variable(tf.random_uniform([1],-1.0,1.0))
    12            tf.summary.histogram('weight',weight)
    13      with tf.name_scope('biases'):
    14            bias = tf.Variable(tf.zeros([1]))
    15            tf.summary.histogram('bias',bias)
    16 ##get y_prediction
    17 with tf.name_scope('y_prediction'):
    18      y_prediction = weight*x_data+bias
    19 ##compute the loss
    20 with tf.name_scope('loss'):
    21      loss = tf.reduce_mean(tf.square(y_data-y_prediction))
    22      tf.summary.scalar('loss',loss)
    23 ##creat optimizer
    24 optimizer = tf.train.GradientDescentOptimizer(0.5)
    25 #creat train ,minimize the loss 
    26 with tf.name_scope('train'):
    27      train = optimizer.minimize(loss)
    28 #creat init
    29 with tf.name_scope('init'): 
    30      init = tf.global_variables_initializer()
    31 ##creat a Session 
    32 sess = tf.Session()
    33 #merged
    34 merged = tf.summary.merge_all()
    35 ##initialize
    36 writer = tf.summary.FileWriter("logs/", sess.graph)
    37 sess.run(init)
    38 ## Loop
    39 for step  in  range(101):
    40     sess.run(train)
    41     rs=sess.run(merged)
    42     writer.add_summary(rs, step)

    闲麻烦,我把打印的去掉了,这里多了几个函数,tf.histogram(对应tensorboard中的scalar)和tf.scalar函数(对应tensorboard中的distribution和histogram)是制作变化图表的,两者差不多,使用方式可以参考上面代码,一般是第一项字符命名,第二项就是要记录的变量了,最后用tf.summary.merge_all对所有训练图进行合并打包,最后必须用sess.run一下打包的图,并添加相应的记录。

     运行过程与上面两个一样

    下面来看看tensorboard中的训练图吧

    scalar中的loss训练图

    distribution中的weight和bias的训练图

    histogram中的weight和bias的训练图

    我们可以根据训练图,对学习情况进行评估,比如我们看损失训练图,可以看到现在是一条慢慢减小的曲线,最后的值趋近趋近于0(这里趋近于0是由于我选的模型太容易训练了,误差可以逼近0,同时又能很好的表征系统的模型,在现实情况,往往都有误差,趋近于0反而是过拟合),这符合本意,就是要最小化loss,如果loss的曲线最后没有平滑趋近一个数,则说明训练的力度还不够,还有加大次数,如果loss还很大,说明学习算法不太理想,需改变当前的算法,去实现更小的loss,另外两幅图与loss类似,最后都是要趋近一个数的,没有趋近和上下浮动都是有问题的。

    结尾

    tensorboard的博客结束了,我写的只是基础部分,更多东西还请看官方的文档和教程,希望这篇博客能对你学习tensorboard有帮助!

    notebook链接: https://pan.baidu.com/s/1o8lzN1g 密码: mbv8

     

  • 相关阅读:
    JavaWeb的三大作用域
    软件工程最后一次作业
    软件工程第四次作业
    软件工程第三次作业
    软件工程第二次作业
    2020软件工程第一次作业
    新建Maven项目报错:Cannot resolve plugin org.apache.maven.plugins:maven-clean-plugin:x.x
    浅谈C++ STL
    C++中几种输入输出cin、cin.getline()、getline()、sscanf()、sprintf()、gets()等
    包含头文件的问题之1.7编程基础之字符串 24:单词的长度
  • 原文地址:https://www.cnblogs.com/fydeblog/p/7429344.html
Copyright © 2011-2022 走看看