zoukankan      html  css  js  c++  java
  • 使用XGBoost实现多分类预测的实践

    使用XGBoost实现多分类预测的实践代码

    import pandas as pd
    import numpy as np
    import xgboost as xgb
    from sklearn.preprocessing import LabelEncoder
    from sklearn.model_selection import KFold
    import matplotlib.pyplot as plt
    import seaborn as sns
    import gc
    
    ## read the raw train / test tables (train first, then test)
    train_data, test_data = (
        pd.read_csv(csv_path)
        for csv_path in ('../../data/train.csv', '../../data/test.csv')
    )
    # number of boosting rounds handed to xgb.train further down
    num_round = 1000
    
    ## one-hot encode the categorical columns over train + test together
    # Test rows get a sentinel label of -1 so the two sets can be split apart again below.
    test_data['label'] = -1
    data = pd.concat([train_data, test_data])
    cate_feature = ['gender', 'cell_province', 'id_province', 'id_city', 'rate', 'term']
    for col in cate_feature:
        # Replace the raw values by integer codes, then expand the codes to indicator columns.
        data[col] = LabelEncoder().fit_transform(data[col])
        onehot = pd.get_dummies(data[col])
        # Indicator columns are named <column><1-based position>, e.g. gender1, gender2, ...
        onehot.columns = ['{}{}'.format(col, pos) for pos in range(1, onehot.shape[1] + 1)]
        data = pd.concat([data, onehot], axis=1)
    # The encoded source columns are no longer needed once their indicators exist.
    data = data.drop(columns=cate_feature)

    # Split back into train / test using the sentinel label.
    is_test_row = data['label'] == -1
    train = data[~is_test_row]
    test = data[is_test_row]

    ## free the intermediate frames
    del data, train_data, test_data
    gc.collect()
    
    ## choose the model inputs: every column except the dates and the label
    del_feature = ['auditing_date', 'due_date', 'label']
    features = [col for col in train.columns if col not in del_feature]

    ## build the feature matrix / integer label vector; keep only feature columns in test
    train_y = train['label'].astype(int).to_numpy()
    train_x = train[features]
    test = test[features]
    
    ## multi-class XGBoost parameters
    params = {
        'booster': 'gbtree',
        # softprob makes predict() return one probability per class. The fold
        # averaging and the final argmax below need that; with softmax, predict()
        # returns class ids, and averaging ids across folds is meaningless.
        'objective': 'multi:softprob',
        'num_class': 33,
        'eval_metric': 'mlogloss',
        'gamma': 0.1,
        'max_depth': 8,
        'alpha': 0,
        'lambda': 0,
        'subsample': 0.7,
        'colsample_bytree': 0.5,
        'min_child_weight': 3,
        'silent': 0,
        'eta': 0.03,
        'nthread': -1,
        'seed': 2019,
    }

    folds = KFold(n_splits=5, shuffle=True, random_state=2019)
    num_class = params['num_class']
    # One row per sample, one column per class.
    prob_oof = np.zeros((train_x.shape[0], num_class))        # out-of-fold probabilities
    test_pred_prob = np.zeros((test.shape[0], num_class))     # fold-averaged test probabilities

    ## train and predict
    # The test matrix is identical for every fold, so build it once.
    test_dmatrix = xgb.DMatrix(test)
    fold_importances = []
    for fold_, (trn_idx, val_idx) in enumerate(folds.split(train_x)):
        print("fold {}".format(fold_ + 1))
        trn_data = xgb.DMatrix(train_x.iloc[trn_idx], label=train_y[trn_idx])
        val_data = xgb.DMatrix(train_x.iloc[val_idx], label=train_y[val_idx])

        # Early stopping monitors the last entry of the watchlist (the validation fold).
        watchlist = [(trn_data, 'train'), (val_data, 'valid')]
        clf = xgb.train(params, trn_data, num_round, watchlist, verbose_eval=20, early_stopping_rounds=50)

        # Predict with the trees up to the best early-stopping iteration only.
        prob_oof[val_idx] = clf.predict(val_data, ntree_limit=clf.best_ntree_limit)

        # Per-fold split counts for each feature the booster actually used.
        fscore = clf.get_fscore()
        fold_importance_df = pd.DataFrame({
            "Feature": list(fscore.keys()),
            "importance": list(fscore.values()),
        })
        fold_importance_df["fold"] = fold_ + 1
        fold_importances.append(fold_importance_df)

        # Each fold contributes an equal share to the averaged test probabilities.
        test_pred_prob += clf.predict(test_dmatrix, ntree_limit=clf.best_ntree_limit) / folds.n_splits

    feature_importance_df = pd.concat(fold_importances, axis=0)
    # Final predicted class = most probable class after averaging over folds.
    result = np.argmax(test_pred_prob, axis=1)
    
    
    ## plot feature importance
    # Rank features by their mean importance across folds; this ranking fixes the bar order.
    cols = (feature_importance_df[["Feature", "importance"]]
            .groupby("Feature")
            .mean()
            .sort_values(by="importance", ascending=False)
            .index)
    best_features = feature_importance_df.loc[feature_importance_df.Feature.isin(cols)]
    plt.figure(figsize=(8, 15))
    # barplot aggregates the per-fold rows of each feature into a single bar;
    # `order` keeps the bars sorted by the fold-averaged importance computed above.
    sns.barplot(y="Feature",
                x="importance",
                data=best_features,
                order=list(cols))
    plt.title('XGBoost Features (avg over folds)')
    plt.tight_layout()
    plt.savefig('../../result/xgb_importances.png')

    参考代码链接为:https://github.com/ikkyu-wen/data_mining_models (见该仓库中使用 XGBoost 实现多分类的部分)。

  • 相关阅读:
    Django Web开发学习笔记(1)
    SessionFactory 执行原生态的SQL语句
    Java中使用FileputStream导致中文乱码问题的修改方案
    JavaScript中的namespace
    SpringMVC配置文件
    Python 贝叶斯分类
    Struct(二)
    Struct2 (一)
    SpingMVC ModelAndView, Model,Control以及参数传递
    window.onload
  • 原文地址:https://www.cnblogs.com/wyhluckdog/p/12194457.html
Copyright © 2011-2022 走看看