zoukankan      html  css  js  c++  java
  • 【tf.keras】实现 F1 score、precision、recall 等 metric

    tf.keras.metric 里面竟然没有实现 F1 score、recall、precision 等指标,一开始觉得真不可思议。但这是有原因的,这些指标在 batch-wise 上计算都没有意义,需要在整个验证集上计算,而 tf.keras 在训练过程(包括验证集)中计算 acc、loss 都是一个 batch 计算一次的,最后再平均起来。Keras 2.0 版本将 precision, recall, fbeta_score, fmeasure 等 metrics 移除了。

    虽然 tf.keras.metric 中没有实现 f1 socre、precision、recall,但我们可以通过 tf.keras.callbacks.Callback 实现。即在每个 epoch 末尾,在整个 val 上计算 f1、precision、recall。

    一些博客实现了二分类下的 f1 socre、precision、recall,如下所示:

    以下代码实现了多分类下对验证集 F1 值、precision、recall 的计算,并且保存 val_f1 值最好的模型:

    import tensorflow as tf
    
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import f1_score, recall_score, precision_score
    import numpy as np
    import os
    
    
    class Metrics(tf.keras.callbacks.Callback):
        def __init__(self, valid_data):
            super(Metrics, self).__init__()
            self.validation_data = valid_data
    
        def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            val_predict = np.argmax(self.model.predict(self.validation_data[0]), -1)
            val_targ = self.validation_data[1]
            if len(val_targ.shape) == 2 and val_targ.shape[1] != 1:
                val_targ = np.argmax(val_targ, -1)
    
            _val_f1 = f1_score(val_targ, val_predict, average='macro')
            _val_recall = recall_score(val_targ, val_predict, average='macro')
            _val_precision = precision_score(val_targ, val_predict, average='macro')
    
            logs['val_f1'] = _val_f1
            logs['val_recall'] = _val_recall
            logs['val_precision'] = _val_precision
            print(" — val_f1: %f — val_precision: %f — val_recall: %f" % (_val_f1, _val_precision, _val_recall))
            return
    
    
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    
    x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=10000, random_state=32)
    
    # LeNet-5
    model = tf.keras.models.Sequential([
        tf.keras.layers.Input(shape=(32, 32, 3)),
        tf.keras.layers.Conv2D(6, 5, activation='relu'),
        tf.keras.layers.AveragePooling2D(),
        tf.keras.layers.Conv2D(16, 5, activation='relu'),
        tf.keras.layers.AveragePooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(120, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(84, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    
    if not os.path.exists('./checkpoints'):
        os.makedirs('./checkpoints')
    
    # 按照 val_f1 保存模型
    ck_callback = tf.keras.callbacks.ModelCheckpoint('./checkpoints/weights.{epoch:02d}-{val_f1:.4f}.hdf5',
                                                     monitor='val_f1', 
                                                     mode='max', verbose=2,
                                                     save_best_only=True,
                                                     save_weights_only=True)
    tb_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs', profile_batch=0)
    model.fit(x_train, y_train,
              validation_data=(x_val, y_val),
              epochs=100,
              callbacks=[Metrics(valid_data=(x_val, y_val)),
                         ck_callback,
                         tb_callback])
    
    

    注意 Metrics()ck_callback 两个 callback 的顺序,互换之后将报错。

    References

    How to calculate F1 Macro in Keras? -- StackOverflow
    How to compute f1 score for each epoch in Keras -- Thong Nguyen
    keras如何求分类问题中的准确率和召回率? - 鱼塘邓少的回答 - 知乎
    Keras 2.0 release notes -- keras-team/keras

  • 相关阅读:
    html 简介
    MySQL事务等了解知识
    MySQL—navicat&&练习&&pymysql
    MySQL查询表(一)
    作业
    MySQL约束&&表关系
    mysql数据类型
    初识mysql
    dll 原理解析
    又过了一天
  • 原文地址:https://www.cnblogs.com/wuliytTaotao/p/11986580.html
Copyright © 2011-2022 走看看