zoukankan      html  css  js  c++  java
  • TensorBoard的使用

    # coding: utf-8
    from __future__ import print_function
    from __future__ import division
    
    import tensorflow as tf
    import numpy as np
    import os
    import argparse
    
    
    def dense_to_one_hot(input_data, class_num):
        """Convert integer class labels to a one-hot float matrix.

        Args:
            input_data: numpy array of integer labels. Its first dimension
                gives the number of rows, and it is flattened with ``ravel``.
            class_num: number of classes, i.e. columns in the result.

        Returns:
            A float array of shape ``(input_data.shape[0], class_num)``. Each
            row holds a single 1 at the column given by that row's label.
        """
        rows = input_data.shape[0]
        one_hot = np.zeros((rows, class_num))
        # In row-major flat indexing, cell (i, label_i) is at i * class_num + label_i.
        flat_positions = np.arange(rows) * class_num + input_data.ravel()
        one_hot.flat[flat_positions] = 1
        return one_hot
    
    
    def build_parser():
        """Parse this script's command-line options and return the result.

        Despite the name, this returns the parsed ``argparse.Namespace``, not
        the parser. All four options are required strings: ``--train_path``,
        ``--test_path``, ``--model_path`` and ``--board_dir``.
        """
        parser = argparse.ArgumentParser()
        for option in ('--train_path', '--test_path', '--model_path', '--board_dir'):
            parser.add_argument(option, type=str, required=True)
        return parser.parse_args()
    
    
    def variable_summaries(var):
        """Create TensorBoard summary ops describing the tensor ``var``.

        Under a ``summaries`` name scope, this adds scalar summaries for the
        mean, standard deviation, max and min of ``var``, plus a histogram
        summary of its values. Nothing is returned.
        """
        with tf.name_scope('summaries'):
            mean = tf.reduce_mean(var)
            tf.summary.scalar('mean', mean)
            # Population standard deviation, computed in its own sub-scope.
            with tf.name_scope('stddev'):
                deviation = var - mean
                stddev = tf.sqrt(tf.reduce_mean(tf.square(deviation)))
            tf.summary.scalar('stddev', stddev)
            for label, reducer in (('max', tf.reduce_max), ('min', tf.reduce_min)):
                tf.summary.scalar(label, reducer(var))
            tf.summary.histogram('histogram', var)
    
    p = build_parser()
    # Start from an empty log directory so TensorBoard shows only this run.
    if tf.gfile.Exists(p.board_dir):
        tf.gfile.DeleteRecursively(p.board_dir)
    tf.gfile.MakeDirs(p.board_dir)

    # CSV rows: columns 0-1 are the two features, column 2 is the class label.
    origin_train = np.genfromtxt(p.train_path, delimiter=',')
    data_train = origin_train[:, 0:2]
    labels_train = origin_train[:, 2]

    # Evaluation data must come from --test_path and be sliced from the test
    # array; reading the training file here makes "test" metrics meaningless.
    origin_test = np.genfromtxt(p.test_path, delimiter=',')
    data_test = origin_test[:, 0:2]
    labels_test = origin_test[:, 2]
    
    
    learning_rate = 0.001
    training_epochs = 5000
    display_step = 1  # NOTE(review): unused; summaries are logged every 100 epochs below.

    n_features = 2
    n_class = 2
    x = tf.placeholder(tf.float32, [None, n_features], "input")
    y = tf.placeholder(tf.float32, [None, n_class])
    with tf.name_scope('W'):
        W = tf.Variable(tf.zeros([n_features, n_class]), name="w")
        variable_summaries(W)
    with tf.name_scope('b'):
        b = tf.Variable(tf.zeros([n_class]), name="b")
        variable_summaries(b)

    # Linear model: logits = x @ W + b.
    scores = tf.nn.xw_plus_b(x, W, b, name='scores')

    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=scores, labels=y))
    tf.summary.scalar('cross_entropy', cost)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

    # Merge every summary defined above into one op; train/test runs are
    # written to separate subdirectories so TensorBoard can overlay them.
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(os.path.join(p.board_dir, 'train'))
    test_writer = tf.summary.FileWriter(os.path.join(p.board_dir, 'test'))

    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    # The inputs do not change between epochs, so build the feed dicts once
    # instead of re-encoding the one-hot labels on every iteration.
    train_feed = {x: data_train,
                  y: dense_to_one_hot(labels_train.astype(int), n_class)}
    test_feed = {x: data_test,
                 y: dense_to_one_hot(labels_test.astype(int), n_class)}

    try:
        with tf.Session() as sess:
            sess.run(init)
            for epoch in range(training_epochs):
                _, c = sess.run([optimizer, cost], feed_dict=train_feed)
                if epoch % 100 == 0:
                    summary, c = sess.run([merged, cost], feed_dict=train_feed)
                    train_writer.add_summary(summary, epoch)
                    # The test writer must record summaries evaluated on the
                    # test data, not a copy of the training summary.
                    test_summary = sess.run(merged, feed_dict=test_feed)
                    test_writer.add_summary(test_summary, epoch)
            saver.save(sess, p.model_path)
    finally:
        # Flush and close the event files even if training fails.
        train_writer.close()
        test_writer.close()
    
  • 相关阅读:
    Triangle LOVE
    数据传送指令具体解释
    关于C++String字符串的使用
    TCP/IP基础(一)
    java打开目录(含判断操作系统工具类和解压缩工具类)
    hdu-1848 Fibonacci again and again
    opencv2对读书笔记——图像二值化——thresholded函数
    安卓中四种点击事件
    @MappedSuperclass注解的使用说明
    Android客户端采用HTTP协议POST方式请求与服务端进行数据交互
  • 原文地址:https://www.cnblogs.com/zhouyang209117/p/8297743.html
Copyright © 2011-2022 走看看