zoukankan      html  css  js  c++  java
  • tensorboard可视化

    https://www.cnblogs.com/felixwang2/p/9184404.html 

     根据上面的这个博客学习的。

    中间遇到过问题,就是运行时报错:

    Failed precondition: could not dlopen DSO: cupti64_90.dll; dlerror: cupti64_90.dll not found

     好在后来找到办法:

    将cupti64_90.dll从目录

    C:Program FilesNVIDIA GPU Computing ToolkitCUDAv9.0extrasCUPTIlibx64

    复制到设置了环境变量的这个目录下:

    C:Program FilesNVIDIA GPU Computing ToolkitCUDAv9.0in

      1 #*- coding:utf-8 -*
      2 # https://www.cnblogs.com/felixwang2/p/9184404.html
      3 # TensorFlow(八):tensorboard可视化
      4 
      5 import tensorflow as tf
      6 from tensorflow.examples.tutorials.mnist import input_data
      7 from tensorflow.contrib.tensorboard.plugins import projector
      8 
      9 # 载入数据集
     10 mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
     11 # 运行次数
     12 max_steps = 1001
     13 # 图片数量
     14 image_num = 3000  # 最多10000,因为测试集为10000
     15 # 文件路径
     16 DIR = "F:/document/PyCharm/temp/"
     17 
     18 # 定义会话
     19 gpu_options = tf.GPUOptions(allow_growth=True)
     20 sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
     21 
     22 # 载入图片
     23 embedding = tf.Variable(tf.stack(mnist.test.images[:image_num]), trainable=False, name='embedding')
     24 
     25 
     26 # 参数概要
     27 def variable_summaries(var):
     28     with tf.name_scope('summaries'):
     29         mean = tf.reduce_mean(var)
     30         tf.summary.scalar('mean', mean)  # 平均值
     31         with tf.name_scope('stddev'):
     32             stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
     33         tf.summary.scalar('stddev', stddev)  # 标准差
     34         tf.summary.scalar('max', tf.reduce_max(var))  # 最大值
     35         tf.summary.scalar('min', tf.reduce_min(var))  # 最小值
     36         tf.summary.histogram('histogram', var)  # 直方图
     37 
     38 
     39 # 命名空间
     40 with tf.name_scope('input'):
     41     # 这里的none表示第一个维度可以是任意的长度
     42     x = tf.placeholder(tf.float32, [None, 784], name='x-input')
     43     # 正确的标签
     44     y = tf.placeholder(tf.float32, [None, 10], name='y-input')
     45 
     46 # 显示图片
     47 with tf.name_scope('input_reshape'):
     48     image_shaped_input = tf.reshape(x, [-1, 28, 28, 1])  # -1表示不确定的值
     49     tf.summary.image('input', image_shaped_input, 10)  # 一共放10张图片
     50 
     51 with tf.name_scope('layer'):
     52     # 创建一个简单神经网络
     53     with tf.name_scope('weights'):
     54         W = tf.Variable(tf.zeros([784, 10]), name='W')
     55         variable_summaries(W)
     56     with tf.name_scope('biases'):
     57         b = tf.Variable(tf.zeros([10]), name='b')
     58         variable_summaries(b)
     59     with tf.name_scope('wx_plus_b'):
     60         wx_plus_b = tf.matmul(x, W) + b
     61     with tf.name_scope('softmax'):
     62         prediction = tf.nn.softmax(wx_plus_b)
     63 
     64 with tf.name_scope('loss'):
     65     # 交叉熵代价函数
     66     loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
     67     tf.summary.scalar('loss', loss)
     68 with tf.name_scope('train'):
     69     # 使用梯度下降法
     70     train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
     71 
     72 # 初始化变量
     73 sess.run(tf.global_variables_initializer())
     74 
     75 with tf.name_scope('accuracy'):
     76     with tf.name_scope('correct_prediction'):
     77         # 结果存放在一个布尔型列表中
     78         correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))  # argmax返回一维张量中最大的值所在的位置
     79     with tf.name_scope('accuracy'):
     80         # 求准确率
     81         accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  # 把correct_prediction变为float32类型
     82         tf.summary.scalar('accuracy', accuracy)
     83 
     84 # 产生metadata文件
     85 if tf.gfile.Exists(DIR + 'projector/projector/metadata.tsv'):  # 检测是否已存在
     86     tf.gfile.DeleteRecursively(DIR + 'projector/projector/metadata.tsv')
     87 with open(DIR + 'projector/projector/metadata.tsv', 'w') as f:
     88     labels = sess.run(tf.argmax(mnist.test.labels[:], 1))
     89     for i in range(image_num):
     90         f.write(str(labels[i]) + '
    ')
     91 
     92         # 合并所有的summary
     93 merged = tf.summary.merge_all()
     94 
     95 projector_writer = tf.summary.FileWriter(DIR + 'projector/projector', sess.graph)
     96 saver = tf.train.Saver()  # 用来保存网络模型
     97 config = projector.ProjectorConfig()  # 定义了配置文件
     98 embed = config.embeddings.add()
     99 embed.tensor_name = embedding.name
    100 embed.metadata_path = DIR + 'projector/projector/metadata.tsv'
    101 embed.sprite.image_path = DIR + 'projector/data/mnist_10k_sprite.png'
    102 embed.sprite.single_image_dim.extend([28, 28])
    103 projector.visualize_embeddings(projector_writer, config)  # 可视化的一个工具
    104 
    105 for i in range(max_steps):
    106     # 每个批次100个样本
    107     batch_xs, batch_ys = mnist.train.next_batch(100)
    108 
    109     run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    110     run_metadata = tf.RunMetadata()
    111 
    112     summary, _ = sess.run([merged, train_step], feed_dict={x: batch_xs, y: batch_ys}, options=run_options,
    113                           run_metadata=run_metadata)
    114     projector_writer.add_run_metadata(run_metadata, 'step%03d' % i)
    115     projector_writer.add_summary(summary, i)
    116 
    117     # 每训练100次打印准确率
    118     if i % 100 == 0:
    119         acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
    120         print("Iter " + str(i) + ", Testing Accuracy= " + str(acc))
    121 
    122 # 训练完保存模型
    123 saver.save(sess, DIR + 'projector/projector/a_model.ckpt', global_step=max_steps)
    124 projector_writer.close()
    125 sess.close()
    View Code

  • 相关阅读:
    mysql:there can be only one auto column...
    idea2019版搜索不到插件解决方案(亲测有效)
    Error querying database. Cause: java.lang.IllegalArgumentException:Failed to decrypt.
    Spring Security踩坑记录(静态资源放行异常)
    WebView2简单试用(五)—— 自定义用户数据文件夹
    WebView2简单试用(四)—— 使用固定版本的Edge Runtime
    WebView2简单试用(三)—— 新窗口打开页面的处理
    WebView2简单试用(二)—— 基本操作
    WebView2简单试用(一)—— 开始
    Playwright入门 —— 简介
  • 原文地址:https://www.cnblogs.com/juluwangshier/p/11427084.html
Copyright © 2011-2022 走看看