zoukankan      html  css  js  c++  java
  • 解决TensorBoard训练集和测试集指标只能分开显示的问题(基于Keras)

    参考https://stackoverflow.com/questions/47877475/keras-tensorboard-plot-train-and-validation-scalars-in-a-same-figure
    tensorflow版本:1.13.1
    keras版本:2.2.4
    定义一个继承自TensorBoard的TrainValTensorBoard:训练指标由父类写入log_dir下的training子目录,验证指标去掉val_前缀后写入validation子目录,这样同名的训练/验证指标就会显示在同一张图上。

    import os
    import tensorflow as tf
    from keras.callbacks import TensorBoard
    
    class TrainValTensorBoard(TensorBoard):
        """TensorBoard callback that writes training and validation metrics
        to sibling run directories so they can share a chart.

        The parent ``TensorBoard`` callback logs to ``<log_dir>/training``.
        Validation metrics (log keys starting with ``val_``) are written by a
        separate writer to ``<log_dir>/validation`` with the ``val_`` prefix
        removed, so both runs use the same tag names.
        """

        # Prefix that marks a validation metric in the epoch logs.
        _VAL_PREFIX = 'val_'

        def __init__(self, log_dir='./logs', **kwargs):
            """Configure the two log directories.

            Args:
                log_dir: Root directory; 'training' and 'validation'
                    subdirectories are derived from it.
                **kwargs: Forwarded unchanged to ``TensorBoard.__init__``.
            """
            # Make the original `TensorBoard` log to a subdirectory 'training'
            training_log_dir = os.path.join(log_dir, 'training')
            super(TrainValTensorBoard, self).__init__(training_log_dir, **kwargs)

            # Log the validation metrics to a separate subdirectory
            self.val_log_dir = os.path.join(log_dir, 'validation')

        def set_model(self, model):
            """Create the validation writer, then defer to the parent."""
            self.val_writer = tf.summary.FileWriter(self.val_log_dir)
            super(TrainValTensorBoard, self).set_model(model)

        def on_epoch_end(self, epoch, logs=None):
            """Write ``val_*`` metrics via the validation writer and pass the
            remaining logs to ``TensorBoard.on_epoch_end``.

            Args:
                epoch: Epoch index, used as the summary step.
                logs: Mapping of metric name to value; may be None.
            """
            logs = logs or {}
            prefix = self._VAL_PREFIX

            train_logs = {}
            for key, value in logs.items():
                if not key.startswith(prefix):
                    train_logs[key] = value
                    continue

                # Strip only the leading prefix: `str.replace` would also
                # remove later occurrences (e.g. 'val_eval_loss' -> 'eloss')
                # and the tag would no longer match the training metric.
                tag = key[len(prefix):]

                # Values may be numpy scalars/arrays or plain Python numbers;
                # only the former provide `.item()`.
                scalar = value.item() if hasattr(value, 'item') else value

                summary = tf.Summary()
                summary_value = summary.value.add()
                summary_value.simple_value = scalar
                summary_value.tag = tag
                self.val_writer.add_summary(summary, epoch)
            self.val_writer.flush()

            # Pass the remaining logs to `TensorBoard.on_epoch_end`
            super(TrainValTensorBoard, self).on_epoch_end(epoch, train_logs)

        def on_train_end(self, logs=None):
            """Run the parent's end-of-training hook and close the validation
            writer, even if the parent hook raises."""
            try:
                super(TrainValTensorBoard, self).on_train_end(logs)
            finally:
                self.val_writer.close()
    

    使用新的TrainValTensorBoard。

    from keras.models import Sequential
    from keras.layers import Dense
    from keras.datasets import mnist
    
    # Load MNIST, flatten every image to a 784-element float32 vector and
    # divide the pixel values by 255.
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = x_train.reshape(60000, 784).astype('float32') / 255
    x_test = x_test.reshape(10000, 784).astype('float32') / 255

    # Single hidden layer followed by a 10-way softmax output.
    model = Sequential([
        Dense(64, activation='relu', input_shape=(784,)),
        Dense(10, activation='softmax'),
    ])
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )

    # The callback logs training and validation metrics under ./logs
    # (its default log_dir); graph writing is turned off.
    tensorboard_cb = TrainValTensorBoard(write_graph=False)
    model.fit(
        x_train,
        y_train,
        epochs=10,
        validation_data=(x_test, y_test),
        callbacks=[tensorboard_cb],
    )
    

  • 相关阅读:
    父进程pid和子进程pid的大小关系
    static 和extern关键字
    linux源码下载
    tar命令
    USB开发——内核USB驱动+libusb开发方法
    microchip PIC芯片使用方法
    android下4G上网卡
    Modem常用概念
    4G上网卡NIDS拨号之Rmnet驱动
    Uboot源码解析
  • 原文地址:https://www.cnblogs.com/zuotongbin/p/11821906.html
Copyright © 2011-2022 走看看