  • Custom evaluation metrics in Keras

     

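    1. A simple custom metric function:

    Note that you cannot define it the way you would in sklearn, because y_true and y_pred here are tensors, not NumPy arrays, so the computation has to use backend tensor operations. For example: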

    from keras import backend
    
    def rmse(y_true, y_pred):
        return backend.sqrt(backend.mean(backend.square(y_pred - y_true), axis=-1))

    To use it, pass the function in the metrics list:

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[rmse])
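
    A minimal NumPy sketch of what the rmse metric above computes, assuming a model with a single output unit so that y_true and y_pred have shape (batch, 1). In that case the mean over axis=-1 covers only one element, so each per-sample value is simply the absolute error, and the average that Keras reports behaves like a mean absolute error rather than an RMSE over the batch. The function names rmse_last_axis and rmse_overall and the sample values are illustrative and not part of the original post:

```python
import numpy as np

def rmse_last_axis(y_true, y_pred):
    # NumPy mirror of the Keras metric: RMSE over the last axis, one value per sample.
    return np.sqrt(np.mean(np.square(y_pred - y_true), axis=-1))

def rmse_overall(y_true, y_pred):
    # RMSE over every element of the batch at once.
    return np.sqrt(np.mean(np.square(y_pred - y_true)))

# Hypothetical batch of 4 samples with a single output each, shape (4, 1).
y_true = np.array([[0.0], [1.0], [1.0], [0.0]])
y_pred = np.array([[0.1], [0.6], [0.9], [0.8]])

per_sample = rmse_last_axis(y_true, y_pred)    # shape (4,), equals |y_pred - y_true|
mean_of_per_sample = per_sample.mean()         # the kind of average Keras would report
mae = np.mean(np.abs(y_pred - y_true))         # mean absolute error, for comparison
overall = rmse_overall(y_true, y_pred)         # RMSE over the whole batch

print("mean of per-sample values:", mean_of_per_sample)
print("mean absolute error:      ", mae)
print("RMSE over the batch:      ", overall)
```

    If an RMSE over the whole batch is what you want, dropping axis=-1 from the backend.mean call is one option. Even then, the average of batch-level RMSE values over an epoch is not identical to the RMSE over the full dataset.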

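    2. A more complex metric such as AUC:

    Computing AUC requires the whole dataset. If it is computed directly on each batch, the error is large and the value does not reasonably reflect overall performance. Here we use a callback instead, which computes AUC once per epoch: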

    from tensorflow import keras
    from sklearn.metrics import roc_auc_score
    
    class roc_callback(keras.callbacks.Callback):
        def __init__(self,training_data, validation_data):
            
            self.x = training_data[0]
            self.y = training_data[1]
            self.x_val = validation_data[0]
            self.y_val = validation_data[1]
            
        
        def on_train_begin(self, logs={}):
            return
     
        def on_train_end(self, logs={}):
            return
     
        def on_epoch_begin(self, epoch, logs={}):
            return
     
        def on_epoch_end(self, epoch, logs={}):        
            y_pred = self.model.predict(self.x)
            roc = roc_auc_score(self.y, y_pred)      
            
            y_pred_val = self.model.predict(self.x_val)
            roc_val = roc_auc_score(self.y_val, y_pred_val)      
            
            print('\rroc-auc: %s - roc-auc_val: %s' % (str(round(roc,4)),str(round(roc_val,4))),end=100*' '+'\n')
            return
     
        def on_batch_begin(self, batch, logs={}):
            return
     
        def on_batch_end(self, batch, logs={}):
            return   

    Example of using the callback:

    model.fit(X_train, y_train, epochs=10, batch_size=4, 
              callbacks = [roc_callback(training_data=[X_train, y_train], validation_data=[X_test, y_test])] )
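
    To illustrate why the callback evaluates AUC on the full data rather than per batch, here is a small pure-Python sketch. It computes AUC as the fraction of (positive, negative) pairs ranked correctly, then compares the value over all samples with the average over batches of 4. The auc helper, the synthetic labels and scores, and the batch size are assumptions made for illustration and are not taken from the original post:

```python
import random

def auc(labels, scores):
    # AUC as the probability that a random positive outscores a random negative
    # (ties count as 0.5). Returns None when only one class is present.
    pos = [s for l, s in zip(labels, scores) if l == 1]
    neg = [s for l, s in zip(labels, scores) if l == 0]
    if not pos or not neg:
        return None
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# Synthetic data: 80 samples whose scores are noisy but correlated with the labels.
rng = random.Random(2018)
labels = [rng.randint(0, 1) for _ in range(80)]
scores = [min(1.0, 0.5 * l + rng.uniform(0.0, 0.6)) for l in labels]

# AUC on each batch of 4 samples; some batches may contain a single class.
batch_size = 4
batch_aucs = [
    auc(labels[i:i + batch_size], scores[i:i + batch_size])
    for i in range(0, len(labels), batch_size)
]
defined = [a for a in batch_aucs if a is not None]

global_auc = auc(labels, scores)
mean_batch_auc = sum(defined) / len(defined) if defined else None
undefined_batches = len(batch_aucs) - len(defined)

print("AUC over all samples:   ", global_auc)
print("mean of per-batch AUCs: ", mean_batch_auc)
print("batches with undefined AUC:", undefined_batches)
```

    With batches this small, a batch can contain only one class, in which case AUC is undefined, and the remaining per-batch values are coarse because each is based on only a few pairs. Evaluating once per epoch on the full arrays avoids both problems.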

    Full example:

    from tensorflow import keras
    from sklearn import datasets
    from sklearn import model_selection
    from sklearn.metrics import roc_auc_score
    
    def rmse(y_true, y_pred):
        return keras.backend.sqrt(keras.backend.mean(keras.backend.square(y_pred - y_true), axis=-1))
    
    class roc_callback(keras.callbacks.Callback):
        def __init__(self,training_data, validation_data):
            
            self.x = training_data[0]
            self.y = training_data[1]
            self.x_val = validation_data[0]
            self.y_val = validation_data[1]
            
        
        def on_train_begin(self, logs={}):
            return
     
        def on_train_end(self, logs={}):
            return
     
        def on_epoch_begin(self, epoch, logs={}):
            return
     
        def on_epoch_end(self, epoch, logs={}):        
            y_pred = self.model.predict(self.x)
            roc = roc_auc_score(self.y, y_pred)      
            
            y_pred_val = self.model.predict(self.x_val)
            roc_val = roc_auc_score(self.y_val, y_pred_val)      
            
            print('\rroc-auc: %s - roc-auc_val: %s' % (str(round(roc,4)),str(round(roc_val,4))),end=100*' '+'\n')
            return
     
        def on_batch_begin(self, batch, logs={}):
            return
     
        def on_batch_end(self, batch, logs={}):
            return   
        
        
    X, y = datasets.make_classification(n_samples=100, n_features=4, n_classes=2, random_state=2018)
    X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2, random_state=2018)
    print("TrainSet", X_train.shape, "TestSet", X_test.shape)
    
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(20, input_shape=(4,), activation='relu'))
    model.add(keras.layers.Dense(1, activation='sigmoid'))
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[rmse])
    
    model.fit(X_train, y_train, epochs=10, batch_size=4, 
              callbacks = [roc_callback(training_data=[X_train, y_train], validation_data=[X_test, y_test])] )

    Output:

    TrainSet (80, 4) TestSet (20, 4)
    Epoch 1/10
    roc-auc: 0.1604 - roc-auc_val: 0.2738                                                                                                    
    80/80 [==============================] - 0s - loss: 0.8132 - rmse: 0.5298     
    Epoch 2/10
    roc-auc: 0.4874 - roc-auc_val: 0.619                                                                                                    
    80/80 [==============================] - 0s - loss: 0.7432 - rmse: 0.5049     
    Epoch 3/10
    roc-auc: 0.7715 - roc-auc_val: 0.9643                                                                                                    
    80/80 [==============================] - 0s - loss: 0.6821 - rmse: 0.4807     
    Epoch 4/10
    roc-auc: 0.9602 - roc-auc_val: 1.0                                                                                                    
    80/80 [==============================] - 0s - loss: 0.6268 - rmse: 0.4560     
    Epoch 5/10
    roc-auc: 0.9842 - roc-auc_val: 1.0                                                                                                    
    80/80 [==============================] - 0s - loss: 0.5747 - rmse: 0.4301     
    Epoch 6/10
    roc-auc: 0.9956 - roc-auc_val: 1.0                                                                                                    
    80/80 [==============================] - 0s - loss: 0.5230 - rmse: 0.4025     
    Epoch 7/10
    roc-auc: 0.9975 - roc-auc_val: 1.0                                                                                                    
    80/80 [==============================] - 0s - loss: 0.4743 - rmse: 0.3739     
    Epoch 8/10
    roc-auc: 0.9987 - roc-auc_val: 1.0                                                                                                    
    80/80 [==============================] - 0s - loss: 0.4289 - rmse: 0.3454     
    Epoch 9/10
    roc-auc: 0.9987 - roc-auc_val: 1.0
    80/80 [==============================] - 0s - loss: 0.3830 - rmse: 0.3149     
    Epoch 10/10
    roc-auc: 0.9987 - roc-auc_val: 1.0                                                                                                    
    80/80 [==============================] - 0s - loss: 0.3424 - rmse: 0.2865  
  • Original article: https://www.cnblogs.com/shixiangwan/p/9041707.html