  • Wrapping Keras models as scikit-learn estimators

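    With a wrapper, a Sequential model (with a single input) can be used as part of a Scikit-Learn workflow. The relevant wrappers are defined in keras.wrappers.scikit_learn.py.

    Two wrappers are available:

    Classifier interface: keras.wrappers.scikit_learn.KerasClassifier(build_fn=None, **sk_params)

    Regressor interface: keras.wrappers.scikit_learn.KerasRegressor(build_fn=None, **sk_params)

    build_fn is a function that builds, compiles and returns the Keras model. sk_params holds both the arguments of build_fn and the arguments passed on to fit/predict, such as epochs and batch_size.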
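    From our reading of the Keras source, the wrapper routes each entry of sk_params by matching its name against the parameters of the target function (build_fn, fit, predict, ...). The sketch below reproduces only that idea using inspect.signature. `filter_params`, `build_model` and `fit` are hypothetical stand-ins written for illustration and are not Keras functions.

```python
import inspect


def filter_params(fn, params):
    """Keep only the entries of `params` whose names are parameters of `fn`."""
    accepted = inspect.signature(fn).parameters
    return {name: value for name, value in params.items() if name in accepted}


# Hypothetical stand-ins: a model-building function and a training function.
def build_model(optimizer="adam", units=12):
    return {"optimizer": optimizer, "units": units}


def fit(x, y, epochs=1, batch_size=32, verbose=1):
    return {"epochs": epochs, "batch_size": batch_size, "verbose": verbose}


sk_params = {"optimizer": "rmsprop", "epochs": 10, "batch_size": 5}
build_args = filter_params(build_model, sk_params)  # arguments for building the model
fit_args = filter_params(fit, sk_params)            # arguments for training
print(build_args)  # {'optimizer': 'rmsprop'}
print(fit_args)    # {'epochs': 10, 'batch_size': 5}
```

    In this sketch, optimizer goes to the model-building function and epochs/batch_size go to the training call. That is why both kinds of settings can be passed to KerasRegressor in a single call.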

    Reference: https://keras-cn.readthedocs.io/en/latest/scikit-learn_API/

    import numpy as np
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.wrappers.scikit_learn import KerasRegressor
    from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
    from sklearn.metrics import mean_squared_error

    def build_model(optimizer="adam"):
        # create model: 4 input features -> 12 -> 8 -> 1 output
        model = Sequential()
        model.add(Dense(input_dim=4, units=12, activation="relu"))
        model.add(Dense(units=8, activation="relu"))
        # linear output: the target y can be any real number (including negatives),
        # while a sigmoid output is limited to (0, 1)
        model.add(Dense(units=1, activation="linear"))
        # compile model; "accuracy" is meaningless for regression, so track MAE instead
        model.compile(loss="mse", optimizer=optimizer, metrics=["mae"])
        return model
    #######################################################################################
    # create data
    np.random.seed(seed=10)
    X = np.random.randn(100, 4)
    y = np.random.randn(100)

    # split data into a training set and a test set
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

    # use the wrapper to create a sklearn-style estimator;
    # epochs and batch_size are forwarded to model.fit
    model = KerasRegressor(build_fn=build_model, epochs=10, batch_size=5)

    # training
    model.fit(X_train, y_train)
    # predicting
    y_pred = model.predict(X_test)
    # evaluation
    print("mse:" + str(mean_squared_error(y_test, y_pred)))

    # cross-validation: with scoring="neg_mean_squared_error" sklearn returns the
    # *negative* MSE of each fold, so flip the sign before averaging
    neg_mse = cross_val_score(estimator=model, X=X, y=y, cv=5, n_jobs=1, scoring="neg_mean_squared_error")
    print("average value of mse:" + str(-neg_mse.mean()))
    #########################################################################################
    # adjust parameters of model
    # GridSearchCV: try every combination in params with 5-fold cross-validation
    params = {"optimizer": ['rmsprop', 'adam'],
              "epochs": [5, 10],
              "batch_size": [5, 10],
              }

    gridSearchCV = GridSearchCV(estimator=model, param_grid=params, cv=5, scoring="neg_mean_squared_error")
    result = gridSearchCV.fit(X, y)

    print(result.best_params_)
    print(result.best_score_)  # negative MSE of the best combination
    #########################################################################################
    # skopt: Bayesian optimization (Gaussian process) over the hyperparameters
    from skopt.space import Categorical
    from skopt.utils import use_named_args
    from skopt import gp_minimize

    space = [Categorical(categories=['rmsprop', 'adam'], name="optimizer"),
             Categorical(categories=[1, 2, 3], name="epochs")]

    @use_named_args(space)
    def objective(**params):
        model.set_params(**params)
        # gp_minimize minimizes, so return the positive average MSE
        return -np.mean(cross_val_score(model, X, y, cv=5, n_jobs=1, scoring="neg_mean_squared_error"))

    result = gp_minimize(objective, space, n_calls=50, random_state=0)
    print("best score:%.4f" % (result.fun))
    print("best parameters:", result.x)
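    With scoring="neg_mean_squared_error", cross_val_score returns one negative MSE per fold rather than a single average. The NumPy sketch below shows what it computes for a regressor with cv=5: five contiguous, unshuffled folds (mirroring the sklearn KFold defaults), one negative MSE per fold, and the average MSE obtained as minus their mean. The hypothetical `mean_baseline`, which always predicts the training mean, stands in for the Keras model so the sketch runs without Keras. It is not part of the original code.

```python
import numpy as np


def kfold_indices(n_samples, n_splits):
    """Contiguous, unshuffled folds; the first n_samples % n_splits folds get one extra sample."""
    fold_sizes = np.full(n_splits, n_samples // n_splits)
    fold_sizes[: n_samples % n_splits] += 1
    start = 0
    for size in fold_sizes:
        test_idx = np.arange(start, start + size)
        train_idx = np.concatenate([np.arange(0, start), np.arange(start + size, n_samples)])
        yield train_idx, test_idx
        start += size


def neg_mse_scores(X, y, fit_predict, n_splits=5):
    """One score per fold: minus the mean squared error on the held-out fold."""
    scores = []
    for train_idx, test_idx in kfold_indices(len(y), n_splits):
        y_pred = fit_predict(X[train_idx], y[train_idx], X[test_idx])
        scores.append(-np.mean((y[test_idx] - y_pred) ** 2))
    return np.array(scores)


def mean_baseline(X_train, y_train, X_test):
    # Hypothetical stand-in for the Keras model: always predict the training mean.
    return np.full(len(X_test), y_train.mean())


np.random.seed(10)
X = np.random.randn(100, 4)
y = np.random.randn(100)
scores = neg_mse_scores(X, y, mean_baseline)
print("per-fold neg MSE:", scores)
print("average MSE:", -scores.mean())
```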
    
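    You can estimate how much training the two searches trigger by enumerating their candidates with itertools.product. This gives the same combinations as the params grid, though possibly in a different order than sklearn's own ParameterGrid. The grid has 2 x 2 x 2 = 8 candidates. Each is trained on 5 folds, and GridSearchCV refits once on all of X by default, for 41 trainings of the Keras model. The skopt space has only 2 x 3 = 6 distinct points, fewer than n_calls=50, so gp_minimize must evaluate some points more than once.

```python
from itertools import product

# The same grid as the GridSearchCV step.
params = {"optimizer": ["rmsprop", "adam"], "epochs": [5, 10], "batch_size": [5, 10]}
candidates = [dict(zip(params, values)) for values in product(*params.values())]

cv = 5
# every candidate is trained once per fold, plus one refit on the full data (refit=True)
n_trainings = len(candidates) * cv + 1
print(len(candidates), n_trainings)  # 8 41

# The skopt space: 2 optimizers x 3 epoch values.
skopt_points = list(product(["rmsprop", "adam"], [1, 2, 3]))
n_calls = 50
# at least n_calls - len(skopt_points) evaluations must repeat an earlier point
print(len(skopt_points), n_calls - len(skopt_points))  # 6 44
```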
  • Original article: https://www.cnblogs.com/wzdLY/p/9655481.html