  • Keras custom evaluation functions (metrics)

     

    1. Ordinary custom metric functions:

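    Note that you cannot define the function the way you would in sklearn: here y_true and y_pred are tensors, not NumPy arrays, so it has to be written with backend tensor operations. For example: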

    from keras import backend
    
    def rmse(y_true, y_pred):
        return backend.sqrt(backend.mean(backend.square(y_pred - y_true), axis=-1))

    To use it, pass it directly in metrics:

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[rmse])
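    A quick NumPy check of what the rmse function above actually reports for a single-output model such as the sigmoid classifier in the complete example. The batch values are invented. Because axis=-1 averages only over the output dimension, each sample's value is just |y_pred - y_true|, and Keras then averages the per-sample values over the batch, so the number shown as "rmse" equals the mean absolute error; a batch-level RMSE would take the square root after averaging over samples.

```python
import numpy as np

# Invented batch: 4 samples, one sigmoid output each (shape (4, 1)).
y_true = np.array([[0.0], [1.0], [1.0], [0.0]])
y_pred = np.array([[0.2], [0.6], [0.9], [0.5]])

# What rmse() computes per sample: sqrt(mean(square(diff), axis=-1)) == |diff| here.
per_sample = np.sqrt(np.mean(np.square(y_pred - y_true), axis=-1))
# Keras averages the per-sample metric values over the batch.
reported = per_sample.mean()

mae = np.mean(np.abs(y_pred - y_true))                    # mean absolute error
batch_rmse = np.sqrt(np.mean(np.square(y_pred - y_true)))  # RMSE over the whole batch

print(round(reported, 4), round(mae, 4), round(batch_rmse, 4))  # 0.3 0.3 0.3391
```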

    2. More complex metrics such as AUC:

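    AUC has to be computed over the whole dataset. Computing it inside each batch gives a large error and does not reflect overall performance, so here it is written as a callback that computes AUC once per epoch: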

    from tensorflow import keras
    from sklearn.metrics import roc_auc_score
    
    class roc_callback(keras.callbacks.Callback):
        """Compute ROC-AUC on the full training and validation sets once per epoch."""
    
        def __init__(self, training_data, validation_data):
            super().__init__()
            self.x, self.y = training_data
            self.x_val, self.y_val = validation_data
    
        # The other hooks (on_train_begin, on_batch_end, ...) are inherited no-ops,
        # so only on_epoch_end needs to be overridden.
        def on_epoch_end(self, epoch, logs=None):
            # Predict on the whole dataset, not one batch, so AUC reflects the overall ranking.
            y_pred = self.model.predict(self.x, verbose=0)
            roc = roc_auc_score(self.y, y_pred)
    
            y_pred_val = self.model.predict(self.x_val, verbose=0)
            roc_val = roc_auc_score(self.y_val, y_pred_val)
    
            # '\r' returns to the start of the progress-bar line; the padding overwrites leftovers.
            print('\rroc-auc: %s - roc-auc_val: %s' % (round(roc, 4), round(roc_val, 4)),
                  end=100 * ' ' + '\n')
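    To illustrate why the callback predicts on the whole dataset instead of computing AUC per batch, here is a small pure-Python sketch. auc is a hypothetical helper that uses the pairwise definition of ROC-AUC (the fraction of positive/negative pairs ranked correctly, ties counting one half) in place of sklearn's roc_auc_score, and the labels and scores are invented. Averaging per-batch AUCs does not give the AUC of the whole set, and a small batch (the fit call in this post uses batch_size=4) can contain only one class, in which case AUC is undefined; roc_auc_score raises a ValueError in that situation.

```python
from itertools import product


def auc(y_true, scores):
    """ROC-AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 1/2."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC is undefined when y_true contains only one class")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p, n in product(pos, neg))
    return wins / (len(pos) * len(neg))


# Invented labels and sigmoid scores for 8 samples, split into 2 batches of 4.
y = [1, 0, 1, 0, 1, 1, 0, 0]
s = [0.9, 0.8, 0.7, 0.1, 0.4, 0.35, 0.3, 0.2]

full_auc = auc(y, s)
batch_mean_auc = (auc(y[:4], s[:4]) + auc(y[4:], s[4:])) / 2
print(full_auc, batch_mean_auc)  # 0.8125 0.875

# A batch holding a single class has no AUC at all.
try:
    auc([1, 1, 1, 1], [0.9, 0.6, 0.4, 0.2])
except ValueError as err:
    print(err)
```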

    Passing the callback to fit:

    model.fit(X_train, y_train, epochs=10, batch_size=4, 
              callbacks = [roc_callback(training_data=[X_train, y_train], validation_data=[X_test, y_test])] )
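    As an alternative to a callback, tf.keras also provides a built-in keras.metrics.AUC (default num_thresholds=200) that accumulates confusion-matrix counts at a fixed set of thresholds across all batches of an epoch, so its epoch value approximates the AUC of every sample seen rather than an average of per-batch AUCs. The NumPy class below is a simplified, hypothetical sketch of that idea, not the Keras implementation: StreamingAUC, its evenly spaced thresholds, and the data are assumptions.

```python
import numpy as np


class StreamingAUC:
    """Hypothetical helper: accumulate TP/FP counts at fixed thresholds over batches,
    then integrate the ROC curve with the trapezoid rule."""

    def __init__(self, num_thresholds=200):
        # Evenly spaced thresholds in [0, 1]; scores are assumed to be sigmoid outputs.
        self.thresholds = np.linspace(0.0, 1.0, num_thresholds)
        self.tp = np.zeros(num_thresholds)
        self.fp = np.zeros(num_thresholds)
        self.n_pos = 0
        self.n_neg = 0

    def update(self, y_true, y_score):
        y_true = np.asarray(y_true).ravel().astype(bool)
        y_score = np.asarray(y_score, dtype=float).ravel()
        # predicted[i, j] is True when sample j scores at or above threshold i.
        predicted = y_score[None, :] >= self.thresholds[:, None]
        self.tp += (predicted & y_true).sum(axis=1)
        self.fp += (predicted & ~y_true).sum(axis=1)
        self.n_pos += int(y_true.sum())
        self.n_neg += int((~y_true).sum())

    def result(self):
        # Thresholds ascend, so reverse to get FPR/TPR in increasing order,
        # and prepend the (0, 0) corner of the ROC curve.
        fpr = np.concatenate([[0.0], self.fp[::-1] / self.n_neg])
        tpr = np.concatenate([[0.0], self.tp[::-1] / self.n_pos])
        return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2))


# Invented labels and scores, fed in as two batches of 4.
y = [1, 0, 1, 0, 1, 1, 0, 0]
s = [0.9, 0.8, 0.7, 0.1, 0.4, 0.35, 0.3, 0.2]
m = StreamingAUC()
for start in (0, 4):
    m.update(y[start:start + 4], s[start:start + 4])
print(m.result())  # 0.8125
```

    Because the counts are collected while the weights are still changing, a streaming training-set value will generally differ from the callback above, which re-predicts with the end-of-epoch weights; the fixed thresholds also make it an approximation when scores fall closer together than the threshold spacing.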

    Complete example:

    from tensorflow import keras
    from sklearn import datasets
    from sklearn import model_selection
    from sklearn.metrics import roc_auc_score
    
    def rmse(y_true, y_pred):
        return keras.backend.sqrt(keras.backend.mean(keras.backend.square(y_pred - y_true), axis=-1))
    
    class roc_callback(keras.callbacks.Callback):
        """Compute ROC-AUC on the full training and validation sets once per epoch."""
    
        def __init__(self, training_data, validation_data):
            super().__init__()
            self.x, self.y = training_data
            self.x_val, self.y_val = validation_data
    
        # Only on_epoch_end is needed; the other hooks are inherited no-ops.
        def on_epoch_end(self, epoch, logs=None):
            # Predict on the whole dataset so AUC reflects the overall ranking.
            y_pred = self.model.predict(self.x, verbose=0)
            roc = roc_auc_score(self.y, y_pred)
    
            y_pred_val = self.model.predict(self.x_val, verbose=0)
            roc_val = roc_auc_score(self.y_val, y_pred_val)
    
            # '\r' returns to the start of the progress-bar line; the padding overwrites leftovers.
            print('\rroc-auc: %s - roc-auc_val: %s' % (round(roc, 4), round(roc_val, 4)),
                  end=100 * ' ' + '\n')
        
        
    X, y = datasets.make_classification(n_samples=100, n_features=4, n_classes=2, random_state=2018)
    X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2, random_state=2018)
    print("TrainSet", X_train.shape, "TestSet", X_test.shape)
    
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(20, input_shape=(4,), activation='relu'))
    model.add(keras.layers.Dense(1, activation='sigmoid'))
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[rmse])
    
    model.fit(X_train, y_train, epochs=10, batch_size=4, 
              callbacks = [roc_callback(training_data=[X_train, y_train], validation_data=[X_test, y_test])] )

    Output:

    TrainSet (80, 4) TestSet (20, 4)
    Epoch 1/10
    roc-auc: 0.1604 - roc-auc_val: 0.2738                                                                                                    
    80/80 [==============================] - 0s - loss: 0.8132 - rmse: 0.5298     
    Epoch 2/10
    roc-auc: 0.4874 - roc-auc_val: 0.619                                                                                                    
    80/80 [==============================] - 0s - loss: 0.7432 - rmse: 0.5049     
    Epoch 3/10
    roc-auc: 0.7715 - roc-auc_val: 0.9643                                                                                                    
    80/80 [==============================] - 0s - loss: 0.6821 - rmse: 0.4807     
    Epoch 4/10
    roc-auc: 0.9602 - roc-auc_val: 1.0                                                                                                    
    80/80 [==============================] - 0s - loss: 0.6268 - rmse: 0.4560     
    Epoch 5/10
    roc-auc: 0.9842 - roc-auc_val: 1.0                                                                                                    
    80/80 [==============================] - 0s - loss: 0.5747 - rmse: 0.4301     
    Epoch 6/10
    roc-auc: 0.9956 - roc-auc_val: 1.0                                                                                                    
    80/80 [==============================] - 0s - loss: 0.5230 - rmse: 0.4025     
    Epoch 7/10
    roc-auc: 0.9975 - roc-auc_val: 1.0                                                                                                    
    80/80 [==============================] - 0s - loss: 0.4743 - rmse: 0.3739     
    Epoch 8/10
    roc-auc: 0.9987 - roc-auc_val: 1.0                                                                                                    
    80/80 [==============================] - 0s - loss: 0.4289 - rmse: 0.3454     
    Epoch 9/10
    roc-auc: 0.9987 - roc-auc_val: 1.0
    80/80 [==============================] - 0s - loss: 0.3830 - rmse: 0.3149     
    Epoch 10/10
    roc-auc: 0.9987 - roc-auc_val: 1.0                                                                                                    
    80/80 [==============================] - 0s - loss: 0.3424 - rmse: 0.2865  
  • Original post: https://www.cnblogs.com/shixiangwan/p/9041707.html