zoukankan      html  css  js  c++  java
  • SMOTE + RF/MLP demo: using cross_val_score to find the best model parameters 处理不平衡数据的demo代码 先做smote处理 再用交叉验证找到最好的模型参数 实践表明MLP更好

    # _*_coding:UTF-8_*_
    from sklearn.externals.six import StringIO  
    from sklearn import tree
    import pydot 
    import sklearn
    import numpy as np
    import sys
    import pickle
    import os
    from sklearn.cross_validation import train_test_split
    import sklearn.ensemble
    from sklearn.model_selection import cross_val_score
    # from sklearn.ensemble import ExtraTreesClassifier
    from sklearn import preprocessing
    import pdb
    from sklearn.neural_network import MLPClassifier
    from sklearn.metrics import classification_report
    from sklearn.model_selection import StratifiedShuffleSplit
    import os
    import collections
    import imblearn
    def iterbrowse(path):
        """Yield the full path of every file found under *path*, recursively.

        Paths are produced in ``os.walk`` (top-down) order. Only the
        non-directory entries that ``os.walk`` lists for each directory are
        yielded; directories themselves are not.
        """
        for root, _subdirs, names in os.walk(path):
            full_paths = (os.path.join(root, name) for name in names)
            for full_path in full_paths:
                yield full_path
    def get_data(filename, expected_columns=78, skip_columns=3):
        """Read tab-separated feature rows from *filename*.

        Every non-blank line must contain exactly ``expected_columns``
        tab-separated fields. The first ``skip_columns`` fields are dropped and
        the remaining ones are converted to float.

        :param filename: path of the text file to read.
        :param expected_columns: required number of tab-separated fields per
            line (78 for the feature files this script reads).
        :param skip_columns: number of leading fields to discard.
        :returns: list of rows, each a list of floats.
        :raises ValueError: if a line has the wrong number of fields, or a
            kept field cannot be parsed as a float.
        """
        rows = []
        with open(filename) as f:
            for lineno, line in enumerate(f, 1):
                # A blank line carries no sample; skip it instead of aborting
                # the whole file.
                if not line.strip():
                    continue
                fields = line.split("\t")
                if len(fields) != expected_columns:
                    raise ValueError(
                        "%s:%d: expected %d tab-separated fields, got %d"
                        % (filename, lineno, expected_columns, len(fields)))
                rows.append([float(n) for n in fields[skip_columns:]])
        return rows
    # 显示测试结果
    def show_cm(cm, labels):
        """Print every cell of a confusion matrix as a percentage of its row.

        For each pair ``(labels[i], labels[j])`` this prints
        ``cm[i][j]`` as a percentage of row ``i``'s total, followed by the raw
        count and the row total. A row whose total is 0 is reported as 0%.

        :param cm: square confusion matrix (array-like) whose rows/columns are
            ordered like *labels* -- presumably rows = true class and
            columns = predicted class, as from
            ``sklearn.metrics.confusion_matrix``; confirm with the caller.
        :param labels: display names for the classes.
        """
        cm = np.asarray(cm)
        row_totals = cm.sum(axis=1, keepdims=True)
        # Avoid dividing by zero for classes with no samples; those cells stay 0.
        percent = np.divide(cm * 100.0, row_totals,
                            out=np.zeros(cm.shape, dtype=float),
                            where=row_totals != 0)
        print('Confusion Matrix Stats')
        for i, label_i in enumerate(labels):
            for j, label_j in enumerate(labels):
                print("%s/%s: %.2f%% (%d/%d)" % (label_i, label_j, percent[i][j],
                                                 cm[i][j], row_totals[i][0]))
    def save_model_to_disk(name, model, model_dir='.'):
        """Pickle *model* to ``<model_dir>/<name>.model``.

        Prints the serialized size in megabytes before writing.

        :param name: base file name; ``.model`` is appended.
        :param model: any picklable object.
        :param model_dir: destination directory (must already exist).
        :returns: the path the model was written to.
        """
        serialized_model = pickle.dumps(model, protocol=pickle.HIGHEST_PROTOCOL)
        model_path = os.path.join(model_dir, name + '.model')
        print('Storing Serialized Model to Disk (%s:%.2fMeg)'
              % (name, len(serialized_model) / 1024.0 / 1024.0))
        # The context manager guarantees the file is flushed and closed; the
        # previous bare open(...).write(...) leaked the handle.
        with open(model_path, 'wb') as f:
            f.write(serialized_model)
        return model_path
    # Feature indices the author selected as useful. Only referenced by the
    # commented-out variant of the filter in get_wanted_data.
    # NOTE(review): the trailing H/M marks are the author's tags (presumably
    # High/Medium importance) -- confirm.
    wanted_feature = {
        15,  # forward header histogram: median -----H
        12,  # forward header histogram: min -----H
        14,  # forward header histogram: mean -----H
        13,  # forward header histogram: max -----H
        16,  # forward header histogram: std dev -----H
        52,  # backward header histogram: number of distinct lengths -----M
        51,  # backward header histogram: mean --------------H
        47,  # backward header histogram: min --------------H
        48,  # backward header histogram: max --------------H
        # NOTE(review): 49, 50 and 51 are all labelled "mean" in the original
        # comments; likely a copy-paste slip (median / std dev?) -- confirm.
        49,  # backward header histogram: mean --------------H
        50,  # backward header histogram: mean --------------H
        23,  # forward payload histogram: max --------------H
        24,  # forward payload histogram: mean --------------H
        25,  # forward payload histogram: median --------------H
        26,  # forward payload histogram: std dev --------------H
        17,  # forward header histogram: number of distinct lengths ---H
        46,  # backward packet interval (time / packet count) ----H
        28,  # forward payload histogram: count < 128 bytes ----H
        29,  # forward payload histogram: count >= 128 and < 512 bytes ----H
        30,  # forward payload histogram: count >= 512 and < 1024 bytes ----H
        31,  # forward payload histogram: count > 1024 bytes ----H
        57,  # backward payload histogram: min --------------H
        60,  # backward payload histogram: median --------------H
        59,  # backward payload histogram: mean --------------H
        61,  # backward payload histogram: std dev --------------H
        58,  # backward payload histogram: max --------------H
        42,  # backward: packet count of the current flow
        21,  # forward header histogram: count >= 40 bytes -----------------------H
        56,  # backward header histogram: count >= 40 bytes ------------------------H
        65,  # backward payload histogram: count > 1024 bytes ------------------------H
        63,  # backward payload histogram: count < 128 bytes ------------------------H
        64,  # backward payload histogram: count >= 128 and < 512 bytes ------------------------H
        66,  # backward payload histogram: count >= 512 and < 1024 bytes ------------------------H
    }
    # Feature indices dropped by get_wanted_data when filtering is enabled.
    # NOTE(review): 42 is listed both here and in wanted_feature -- confirm intent.
    unwanted_features = {6, 7, 8, 41, 42, 43, 67, 68, 69, 70, 71, 72, 73, 74, 75}
    def get_wanted_data(x, filter_features=False, unwanted=None, offset=6):
        """Optionally drop unwanted feature columns from every row of *x*.

        By default this is a pass-through and returns *x* itself, unchanged.
        This matches the original behaviour, where the filter loop sat
        unreachable after an early ``return x``. Pass ``filter_features=True``
        to drop every position ``i`` of a row for which ``i + offset`` is in
        *unwanted*.

        :param x: sequence of rows, each a sequence of feature values.
        :param filter_features: when False (default), return *x* as-is.
        :param unwanted: set of feature indices to drop; defaults to the
            module-level ``unwanted_features``.
        :param offset: added to each row position before the lookup in
            *unwanted* (6 in the original filter). Presumably this maps row
            positions back to the feature numbering used by
            ``unwanted_features`` -- confirm.
        :returns: *x* itself, or a new list of filtered rows.
        :raises ValueError: if a row does not shrink by exactly
            ``len(unwanted)``, i.e. some unwanted index falls outside the row.
        """
        if not filter_features:
            return x
        if unwanted is None:
            unwanted = unwanted_features
        ans = []
        for item in x:
            row = [data for i, data in enumerate(item) if i + offset not in unwanted]
            # Explicit check rather than assert, so it survives `python -O`.
            if len(row) != len(item) - len(unwanted):
                raise ValueError(
                    "row of length %d lost %d columns, expected to lose %d"
                    % (len(item), len(item) - len(row), len(unwanted)))
            ans.append(row)
        return ans
    def _load_samples(path):
        """Load a 2-D sample matrix from a ``.txt`` or ``.npy`` file.

        Returns an empty list when *path* is not an existing file, so a
        missing class file contributes no samples.

        :raises ValueError: for any extension other than ``.txt``/``.npy``.
        """
        if not os.path.isfile(path):
            return []
        if path.endswith('.txt'):
            return np.atleast_2d(np.genfromtxt(path))
        if path.endswith('.npy'):
            return np.atleast_2d(np.load(path))
        raise ValueError("unsupported sample file (expected .txt or .npy): %s" % path)


    def _smote_resample(X_part, y_part):
        """Oversample the minority class of (X_part, y_part) with SMOTE.

        imblearn renamed ``fit_sample`` to ``fit_resample``; use whichever the
        installed version provides.
        """
        from imblearn.over_sampling import SMOTE
        smote = SMOTE()
        resample = getattr(smote, 'fit_resample', None) or smote.fit_sample
        return resample(X_part, y_part)


    def _class_counts(labels):
        """Return sorted (label, count) pairs, used to log class balance."""
        return sorted(collections.Counter(labels).items())


    if __name__ == '__main__':
        neg_file = "cc_data/black/black_all.txt"
        pos_file = "cc_data/white/white_all.txt"
        X = []
        y = []

        # White samples get label 0 (see the 'white'/'CC' label names).
        pos_set = _load_samples(pos_file)
        X += list(pos_set)
        y += [0] * len(pos_set)
        print("len of white X: %d" % len(X))
        n_white = len(X)

        # Black (CC) samples get label 1.
        neg_set = _load_samples(neg_file)
        X += list(neg_set)
        y += [1] * len(neg_set)
        print("len of black X: %d" % (len(X) - n_white))
        print("len of X: %d" % len(X))
        print("X sample: %s" % (X[:3],))
        print("len of y: %d" % len(y))
        print("y sample: %s" % (y[:3],))

        # Drop the first 3 columns, matching get_data's a[3:] for verify files.
        X = [x[3:] for x in X]
        X = get_wanted_data(X)
        print("filtered X sample: %s" % (X[:1],))

        # Verification sets, reported on separately after each fit.
        black_verify = []
        for f in iterbrowse("todo/top"):
            black_verify += get_data(f)
        black_verify = get_wanted_data(black_verify)
        black_verify_labels = [1] * len(black_verify)
        white_verify = get_wanted_data(get_data("todo/white_verify.txt"))
        white_verify_labels = [0] * len(white_verify)
        unknown_verify = get_wanted_data(get_data("todo/pek_feature74.txt"))
        black_verify2 = get_wanted_data(get_data("todo/x_rat.txt"))
        black_verify_labels2 = [1] * len(black_verify2)

        # SMOTE interpolates between nearest neighbours, so standardize first.
        scaler = preprocessing.StandardScaler().fit(X)
        X = scaler.transform(X)
        print("standard X sample: %s" % (X[:3],))
        black_verify = scaler.transform(black_verify)
        white_verify = scaler.transform(white_verify)
        unknown_verify = scaler.transform(unknown_verify)
        black_verify2 = scaler.transform(black_verify2)

        # Up-weight the must-detect samples (only a handful exist) by appending
        # BLACK_VERIFY_WEIGHT copies. A single concatenate avoids re-copying
        # the whole matrix on each of 200 iterations.
        BLACK_VERIFY_WEIGHT = 200
        X = np.concatenate([X] + [black_verify] * BLACK_VERIFY_WEIGHT)
        y = np.array(y + black_verify_labels * BLACK_VERIFY_WEIGHT)

        clf = None
        for depth in (128, 64, 32):
            print("***" * 20)
            print("hidden_layer_sizes=> %s" % depth)
            sss = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42)
            for train_index, test_index in sss.split(X, y):
                X_train, X_test = X[train_index], X[test_index]
                y_train, y_test = y[train_index], y[test_index]
                print("smote before:")
                print(_class_counts(y_train))
                print(_class_counts(y_test))
                X_train, y_train = _smote_resample(X_train, y_train)
                print("smote after:")
                print(_class_counts(y_train))
                # Oversampled copy of the test fold. Its extra positives are
                # synthetic, so these metrics are optimistic; the plain X_test
                # report is the honest held-out estimate.
                X_test2, y_test2 = _smote_resample(X_test, y_test)

                # LogisticRegression, XGBClassifier and RandomForestClassifier
                # were also tried here; the MLP did best, so only it is trained.
                clf = MLPClassifier(batch_size=128, learning_rate='adaptive', max_iter=1024,
                                    hidden_layer_sizes=(depth,), random_state=666)
                clf.fit(X_train, y_train)

                print("test confusion_matrix:")
                y_pred = clf.predict(X_test)
                print(sklearn.metrics.confusion_matrix(y_test, y_pred))
                print(classification_report(y_test, y_pred))

                print("test confusion_matrix (SMOTE):")
                y_pred2 = clf.predict(X_test2)
                print(sklearn.metrics.confusion_matrix(y_test2, y_pred2))
                print(classification_report(y_test2, y_pred2))

                # Includes the training rows, so this is not a generalization estimate.
                print("all confusion_matrix:")
                y_pred = clf.predict(X)
                print(sklearn.metrics.confusion_matrix(y, y_pred))
                print(classification_report(y, y_pred))

                print("black verify confusion_matrix:")
                black_verify_pred = clf.predict(black_verify)
                print(black_verify_pred)
                print(classification_report(black_verify_labels, black_verify_pred))

                print("black verify2 confusion_matrix:")
                black_verify_pred2 = clf.predict(black_verify2)
                print(black_verify_pred2)
                print(classification_report(black_verify_labels2, black_verify_pred2))

                print("white verify confusion_matrix:")
                white_verify_pred = clf.predict(white_verify)
                print(white_verify_pred)
                print(classification_report(white_verify_labels, white_verify_pred))

                # Unlabelled data: only the raw predictions can be inspected.
                print("unknown_verify:")
                print(clf.predict(unknown_verify))
            print("hidden_layer_sizes=> %s" % depth)
            print("***" * 20)

        # NOTE(review): clf is the MLP from the last fold of the last depth
        # tried, despite the "rf" in the file name.
        model_name = "rf_smote"
        save_model_to_disk(model_name, clf)
        # NOTE(review): X holds BLACK_VERIFY_WEIGHT identical copies of each
        # black_verify row, so CV folds share samples and these scores are
        # optimistic; SMOTE is also not applied inside these folds.
        scores = cross_val_score(clf, X, y, cv=5)
        print("scores:")
        print(scores)


    MLP 隐藏层神经元个数 128

    test confusion_matrix (SMOTE): 测试数据的混淆矩阵
    [[131946    120]
     [   299 131767]]
                 precision    recall  f1-score   support

              0       1.00      1.00      1.00    132066
              1       1.00      1.00      1.00    132066

    avg / total       1.00      1.00      1.00    264132

    all confusion_matrix: 整体数据混淆矩阵
    [[659846    483]
     [    52  32474]]
                 precision    recall  f1-score   support

              0       1.00      1.00      1.00    660329
              1       0.99      1.00      0.99     32526

    avg / total       1.00      1.00      1.00    692855

    black verify confusion_matrix: #需要必须检测出来的样本 OK 都检出了
    [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
     1 1 1 1 1]
                 precision    recall  f1-score   support

              1       1.00      1.00      1.00        42

    avg / total       1.00      1.00      1.00        42

    black verify2 confusion_matrix: # 现网是黑的数据,很难区分的
    [0 0 0 0 0 0 0 1 1 1 1]
                 precision    recall  f1-score   support

              0       0.00      0.00      0.00         0
              1       1.00      0.36      0.53        11

    avg / total       1.00      0.36      0.53        11

    white verify confusion_matrix: # 现网是白的数据 很难区分的
    [1 1 1 1 0 0 0]
                 precision    recall  f1-score   support

              0       1.00      0.43      0.60         7
              1       0.00      0.00      0.00         0

    avg / total       1.00      0.43      0.60         7

    unknown_verify: # 现网采集的 有好些是黑的数据 希望检出率高 但是不能过高
    [1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 0 1 1 0 1
     0 1 1 1 1 0 0 1 0 1 0 0 0 1 1 0 1 1 0 0 1 1 0 0 0 0 0 0 0 0 1 1 1 1 0 0 1] 现网验证检出率还不错


    hidden_layer_sizes=> 64
    smote before:
    [(0, 528263), (1, 26021)]
    [(0, 132066), (1, 6505)]
    smote after:
    [(0, 528263), (1, 528263)]
    test confusion_matrix:
    [[131912    154]
     [    24   6481]]
                 precision    recall  f1-score   support

              0       1.00      1.00      1.00    132066
              1       0.98      1.00      0.99      6505

    avg / total       1.00      1.00      1.00    138571

    test confusion_matrix (SMOTE):
    [[131912    154]
     [   193 131873]]
                 precision    recall  f1-score   support

              0       1.00      1.00      1.00    132066
              1       1.00      1.00      1.00    132066

    avg / total       1.00      1.00      1.00    264132

    all confusion_matrix:
    [[659566    763]
     [    34  32492]]
                 precision    recall  f1-score   support

              0       1.00      1.00      1.00    660329
              1       0.98      1.00      0.99     32526

    avg / total       1.00      1.00      1.00    692855

    black verify confusion_matrix:
    [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
     1 1 1 1 1]
                 precision    recall  f1-score   support

              1       1.00      1.00      1.00        42

    avg / total       1.00      1.00      1.00        42

    black verify2 confusion_matrix:
    [0 0 0 0 0 0 0 1 1 1 1]
                 precision    recall  f1-score   support

              0       0.00      0.00      0.00         0
              1       1.00      0.36      0.53        11

    avg / total       1.00      0.36      0.53        11

    white verify confusion_matrix:
    [1 1 0 1 0 0 0]
                 precision    recall  f1-score   support

              0       1.00      0.57      0.73         7
              1       0.00      0.00      0.00         0

    avg / total       1.00      0.57      0.73         7

    [1 0 1 1 1 0 1 1 1 0 1 0 1 1 1 1 1 1 1 1 1 0 1 1 1 0 0 0 1 0 0 1 0 1 1 0 1
     0 0 1 1 1 0 0 1 1 1 1 0 0 1 1 0 1 1 1 0 0 1 0 0 0 0 0 0 0 0 1 0 0 1 0 0 1]


    test confusion_matrix:
    [[132045     21]
     [    16   4818]]
                 precision    recall  f1-score   support

              0       1.00      1.00      1.00    132066
              1       1.00      1.00      1.00      4834

    avg / total       1.00      1.00      1.00    136900

    test confusion_matrix (SMOTE):
    [[132045     21]
     [   246 131820]]
                 precision    recall  f1-score   support

              0       1.00      1.00      1.00    132066
              1       1.00      1.00      1.00    132066

    avg / total       1.00      1.00      1.00    264132

    all confusion_matrix:
    [[660227    102]
     [    29  24139]]
                 precision    recall  f1-score   support

              0       1.00      1.00      1.00    660329
              1       1.00      1.00      1.00     24168

    avg / total       1.00      1.00      1.00    684497

    black verify confusion_matrix:
    [0 1 0 0 1 1 1 1 1 1 1 0 0 1 0 1 1 1 1 0 0 1 0 0 1 1 1 0 0 0 0 1 1 1 1 1 1
     1 1 1 1 1] 这个是必须要全部检出的
                 precision    recall  f1-score   support

              0       0.00      0.00      0.00         0
              1       1.00      0.67      0.80        42

    avg / total       1.00      0.67      0.80        42

    white verify confusion_matrix:
    [0 0 0 0 0 0 1]
                 precision    recall  f1-score   support

              0       1.00      0.86      0.92         7
              1       0.00      0.00      0.00         0

    avg / total       1.00      0.86      0.92         7

    unknown_verify: 现网的检出太低了!过拟合比较严重。。。。
    [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
     0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]


    test confusion_matrix:
    [[132038     28]
     [    16   4818]]
                 precision    recall  f1-score   support

              0       1.00      1.00      1.00    132066
              1       0.99      1.00      1.00      4834

    avg / total       1.00      1.00      1.00    136900

    test confusion_matrix (SMOTE):
    [[132038     28]
     [   257 131809]]
                 precision    recall  f1-score   support

              0       1.00      1.00      1.00    132066
              1       1.00      1.00      1.00    132066

    avg / total       1.00      1.00      1.00    264132

    all confusion_matrix:
    [[660220    109]
     [    34  24134]]
                 precision    recall  f1-score   support

              0       1.00      1.00      1.00    660329
              1       1.00      1.00      1.00     24168

    avg / total       1.00      1.00      1.00    684497

    black verify confusion_matrix:
    [1 1 0 0 1 1 1 1 1 1 0 0 0 0 0 1 1 1 1 0 1 0 0 0 1 1 1 0 0 0 0 1 1 1 1 1 1
     1 1 1 1 1]
                 precision    recall  f1-score   support

              0       0.00      0.00      0.00         0
              1       1.00      0.64      0.78        42

    avg / total       1.00      0.64      0.78        42

    white verify confusion_matrix:
    [0 0 0 0 0 1 1]
                 precision    recall  f1-score   support

              0       1.00      0.71      0.83         7
              1       0.00      0.00      0.00         0

    avg / total       1.00      0.71      0.83         7

    [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
     0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1]


    test confusion_matrix (SMOTE):
    [[132037     29]
     [   301 131765]]
                 precision    recall  f1-score   support

              0       1.00      1.00      1.00    132066
              1       1.00      1.00      1.00    132066

    avg / total       1.00      1.00      1.00    264132

    all confusion_matrix:
    [[660217    112]
     [    36  24132]]
                 precision    recall  f1-score   support

              0       1.00      1.00      1.00    660329
              1       1.00      1.00      1.00     24168

    avg / total       1.00      1.00      1.00    684497

    black verify confusion_matrix:
    [0 1 0 1 1 1 1 1 1 1 0 0 0 0 0 0 1 1 1 0 0 0 0 0 1 1 1 0 0 0 0 1 0 1 1 1 1
     0 1 1 1 1]
                 precision    recall  f1-score   support

              0       0.00      0.00      0.00         0
              1       1.00      0.55      0.71        42

    avg / total       1.00      0.55      0.71        42

    white verify confusion_matrix:
    [0 0 0 0 0 1 1]
                 precision    recall  f1-score   support

              0       1.00      0.71      0.83         7
              1       0.00      0.00      0.00         0

    avg / total       1.00      0.71      0.83         7

    [0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
     0 1 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]



    test confusion_matrix (SMOTE):
    [[114699  17367]
     [ 11921 120145]]
                 precision    recall  f1-score   support

              0       0.91      0.87      0.89    132066
              1       0.87      0.91      0.89    132066

    avg / total       0.89      0.89      0.89    264132

    all confusion_matrix:
    [[573083  87246]
     [  2877  29649]]
                 precision    recall  f1-score   support

              0       1.00      0.87      0.93    660329
              1       0.25      0.91      0.40     32526

    avg / total       0.96      0.87      0.90    692855

    black verify confusion_matrix:
    [1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 0 0
     1 1 0 1 1]
                 precision    recall  f1-score   support

              0       0.00      0.00      0.00         0
              1       1.00      0.88      0.94        42

    avg / total       1.00      0.88      0.94        42

    black verify2 confusion_matrix:
    [1 1 0 0 0 0 0 1 1 1 1]
                 precision    recall  f1-score   support

              0       0.00      0.00      0.00         0
              1       1.00      0.55      0.71        11

    avg / total       1.00      0.55      0.71        11

    white verify confusion_matrix:
    [1 1 1 1 1 1 1]
                 precision    recall  f1-score   support

              0       0.00      0.00      0.00         7
              1       0.00      0.00      0.00         0

    avg / total       0.00      0.00      0.00         7

    [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
     0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]


    [[132018     48]
     [    11   6494]]
                 precision    recall  f1-score   support

              0       1.00      1.00      1.00    132066
              1       0.99      1.00      1.00      6505

    avg / total       1.00      1.00      1.00    138571

    test confusion_matrix (SMOTE):
    [[132018     48]
     [    82 131984]]
                 precision    recall  f1-score   support

              0       1.00      1.00      1.00    132066
              1       1.00      1.00      1.00    132066

    avg / total       1.00      1.00      1.00    264132

    all confusion_matrix:
    [[660134    195]
     [    29  32497]]
                 precision    recall  f1-score   support

              0       1.00      1.00      1.00    660329
              1       0.99      1.00      1.00     32526

    avg / total       1.00      1.00      1.00    692855

    black verify confusion_matrix:
    [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
     1 1 1 1 1]
                 precision    recall  f1-score   support

              1       1.00      1.00      1.00        42

    avg / total       1.00      1.00      1.00        42

    black verify2 confusion_matrix:
    [0 0 0 0 0 0 0 1 0 1 1]
                 precision    recall  f1-score   support

              0       0.00      0.00      0.00         0
              1       1.00      0.27      0.43        11

    avg / total       1.00      0.27      0.43        11

    white verify confusion_matrix:
    [0 0 1 0 1 0 1]
                 precision    recall  f1-score   support

              0       1.00      0.57      0.73         7
              1       0.00      0.00      0.00         0

    avg / total       1.00      0.57      0.73         7

    [0 0 0 0 0 0 0 1 0 0 1 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 1 1 0 0
     0 1 1 1 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0]

  • 相关阅读:
    Mysql基础(九):MySQL 事务
    java 基本语法(十九)Optional类的使用
    java 基本语法(十八)Lambda (五)Stream API
    java 基本语法(十七)Lambda (四)构造器引用与数组引用
    java 基本语法(十六)Lambda (三)函数式接口
  • 原文地址:https://www.cnblogs.com/bonelee/p/9092499.html
Copyright © 2011-2022 走看看