zoukankan      html  css  js  c++  java
  • 模型融合策略:开发树模型输出叶子节点作为特征到回归器或者分类器的类

    from sklearn.base import BaseEstimator,ClassifierMixin,RegressorMixin
    from sklearn.preprocessing import OneHotEncoder
    import numpy as np
    
    class TreeLeaf(BaseEstimator,ClassifierMixin,RegressorMixin):
        """
        Stack tree models with meta models through one-hot encoded leaf indices.

        Every tree model is fitted on the raw features; the index of the leaf
        each sample falls into (one column per tree) is one-hot encoded and
        used as the feature matrix for every meta model. Predictions of the
        meta models are then combined: averaged for regression, majority-voted
        for classification.

        Parameters
        ----------
        treeModel : list of estimators
            Tree ensembles exposing ``apply(X)`` or ``predict(X, pred_leaf=True)``.
            They are fitted in place (not cloned) by ``fit``.
        metaModel : list of estimators
            Models fitted on the one-hot leaf features, also fitted in place.
        n_estimators : list
            Stored for interface compatibility only; never read by ``fit`` or
            ``predict``.
        goal : {"regression", "classification"}
            Selects how the meta-model predictions are combined.

        Attributes set by ``fit``
        -------------------------
        best_treemodel : list of fitted tree models.
        best_metamodel : list of fitted meta models.
        leaf_list : list of 2-D leaf-index arrays of the training data.
        one_hot_encoder : fitted ``OneHotEncoder`` for the leaf indices.
        """
        def __init__(self,treeModel=[],metaModel=[],n_estimators=[],goal="regression"):
            # The list defaults are kept for backward compatibility; they are
            # only stored and iterated, never mutated, so sharing them between
            # instances is harmless.
            self.treeModel = treeModel
            self.metaModel = metaModel
            self.n_estimators = n_estimators
            self.goal = goal

        @staticmethod
        def _leaf_indices(model,X):
            """Return the leaf index of every sample in every tree as a 2-D array."""
            if hasattr(model,"apply"):
                leaves = model.apply(X)
            else:
                # Estimators without ``apply`` (e.g. current LightGBM sklearn
                # wrappers) expose leaf indices through ``predict``.
                leaves = model.predict(X,pred_leaf=True)
            leaves = np.asarray(leaves)
            # ``apply`` may return 1-D (single tree) or 3-D (gradient boosting
            # with one tree per class); flatten everything after the sample
            # axis so all models can be concatenated column-wise.
            return leaves.reshape(leaves.shape[0],-1)

        def _check_goal(self):
            """Raise ``ValueError`` when ``goal`` is not a supported mode."""
            if self.goal not in ("regression","classification"):
                raise ValueError(
                    "goal must be 'regression' or 'classification', got %r" % (self.goal,)
                )

        def fit(self,X,y):
            """
            Fit all tree models on ``(X, y)``, then all meta models on the
            one-hot encoded leaf indices.

            Returns
            -------
            self

            Raises
            ------
            ValueError
                If ``treeModel`` or ``metaModel`` is empty, or ``goal`` is invalid.
            """
            self._check_goal()
            if not self.treeModel:
                raise ValueError("treeModel must contain at least one tree model")
            if not self.metaModel:
                raise ValueError("metaModel must contain at least one meta model")

            self.best_treemodel = []  # fitted tree models
            self.best_metamodel = []  # fitted meta models
            self.leaf_list = []       # training leaf indices, one 2-D array per tree model

            for model in self.treeModel:
                fitted_tree = model.fit(X,y)
                self.best_treemodel.append(fitted_tree)
                self.leaf_list.append(self._leaf_indices(fitted_tree,X))

            # One column per tree across all tree models.
            leaf_matrix = np.concatenate(self.leaf_list,axis=1)

            # Leaves never reached by a training sample can still be reached at
            # predict time; encode those as all-zero instead of raising.
            self.one_hot_encoder = OneHotEncoder(handle_unknown="ignore")
            x_one_hot = self.one_hot_encoder.fit_transform(leaf_matrix)

            for model in self.metaModel:
                self.best_metamodel.append(model.fit(x_one_hot,y))

            return self

        def predict(self,X):
            """
            Predict by combining the meta models' outputs on the leaf features.

            Returns
            -------
            numpy.ndarray
                For ``goal="regression"`` the element-wise mean of the meta-model
                predictions; for ``goal="classification"`` the most frequent
                label per sample (ties resolved towards the smallest label).

            Raises
            ------
            ValueError
                If ``goal`` is invalid.
            """
            self._check_goal()

            leaf_matrix_pred = np.concatenate(
                [self._leaf_indices(model,X) for model in self.best_treemodel],
                axis=1,
            )
            x_one_hot_pred = self.one_hot_encoder.transform(leaf_matrix_pred)

            y_pred_list = [model.predict(x_one_hot_pred) for model in self.best_metamodel]

            if self.goal == "regression":
                # Builtin ``sum`` has no ``axis`` argument; average across the
                # meta models so the output stays on the scale of ``y``.
                return np.mean(y_pred_list,axis=0)

            # Classification: majority vote per sample. ``np.unique`` returns
            # sorted labels, so ``argmax`` picks the smallest label on ties and
            # works for negative, float and string labels alike.
            votes = np.asarray(y_pred_list)  # one row per meta model
            winners = []
            for sample_votes in votes.T:
                labels,counts = np.unique(sample_votes,return_counts=True)
                winners.append(labels[np.argmax(counts)])
            return np.asarray(winners)
    
    ##################案例测试####################################################
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.datasets import load_iris  
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import LogisticRegression
    from sklearn.svm import SVC
    from sklearn.metrics import accuracy_score
    from lightgbm import LGBMClassifier
     
    # Demo: fuse a random forest and a LightGBM model with two meta classifiers
    # on the iris data set, holding out 30% of the samples for evaluation.
    X,y = load_iris(return_X_y=True)
    X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.3,random_state=0)

    # Tree models whose leaf indices become the meta-model features.
    treeModel_1 = RandomForestClassifier(n_estimators=20)
    treeModel_2 = LGBMClassifier( n_estimators=30)
    # Alternative second tree model:
    #treeModel_2 = GradientBoostingClassifier(n_estimators=30)

    # Meta models trained on the one-hot encoded leaves.
    metaModel_1 = LogisticRegression()
    metaModel_2 = SVC()

    # n_estimators mirrors the tree counts above; TreeLeaf only stores it.
    tl = TreeLeaf(treeModel=[treeModel_1,treeModel_2],metaModel=[metaModel_1,metaModel_2],n_estimators=[20,30],goal="classification")
    tl.fit(X_train,y_train)
    y_pred = tl.predict(X_test)

    # Print the score: a bare expression is only echoed in an interactive
    # session and would be silently discarded when run as a script.
    print(accuracy_score(y_test,y_pred))

    上述代码主要实现了一种模型融合策略:将多个树模型输出的叶子节点作为特征,输入到多个分类器或回归器中,具有一定的扩展性和适应性。后面给出了一个基于随机森林和lightGBM的测试实例,供大家参考。这种模型融合策略在不同的数据和场景下效果不同,关键还是在于特征工程是否做得好;此外,该类方法在训练集上有一定的过拟合倾向。

    欢迎评论和给出意见,如果对你有帮助,请给个关注,激励一下我,谢谢!

  • 相关阅读:
    array_map()与array_shift()搭配使用 PK array_column()函数
    Educational Codeforces Round 8 D. Magic Numbers
    hdu 1171 Big Event in HDU
    hdu 2844 poj 1742 Coins
    hdu 3591 The trouble of Xiaoqian
    hdu 2079 选课时间
    hdu 2191 珍惜现在,感恩生活 多重背包入门题
    hdu 5429 Geometric Progression 高精度浮点数(java版本)
    【BZOJ】1002: [FJOI2007]轮状病毒 递推+高精度
    hdu::1002 A + B Problem II
  • 原文地址:https://www.cnblogs.com/wzdLY/p/9677784.html
Copyright © 2011-2022 走看看