  • Machine Learning: Decision Trees and Random Forests

    I. Introduction

    II. Hands-on Practice

    1. Classifying the wine dataset with a decision tree

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap
    from sklearn import tree, datasets
    from sklearn.model_selection import train_test_split

    # Load the wine dataset and keep only the first two features
    # (alcohol and malic acid) so the decision regions can be drawn in 2D.
    wine = datasets.load_wine()
    X = wine.data[:, :2]
    y = wine.target
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    # Fit a decision tree limited to a depth of 5.
    clf = tree.DecisionTreeClassifier(max_depth=5)
    clf.fit(X_train, y_train)

    # Light colors for the predicted regions, bold colors for the samples.
    cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
    cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])

    # Build a grid over the feature space (step 0.02) and predict every grid point.
    x_min, x_max = X_train[:, 0].min() - 1, X_train[:, 0].max() + 1
    y_min, y_max = X_train[:, 1].min() - 1, X_train[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, .02),
                         np.arange(y_min, y_max, .02))
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)

    # Draw the predicted regions, then overlay all samples colored by true class.
    plt.figure()
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light)
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, edgecolors='k', s=20)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.title("Classifier: DecisionTree (max_depth = 5)")
    plt.show()
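
The code above splits the data into training and test sets but only uses the training set. As a rough sketch of how the held-out part could be used, the following scores the same kind of tree on both sets. The `random_state=0` arguments are an assumption added here so the numbers are repeatable; they are not in the original post.

```python
from sklearn import datasets, tree
from sklearn.model_selection import train_test_split

# Same data as above: first two features of the wine dataset.
wine = datasets.load_wine()
X = wine.data[:, :2]
y = wine.target

# random_state=0 is an assumption for reproducibility, not part of the original code.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = tree.DecisionTreeClassifier(max_depth=5, random_state=0)
clf.fit(X_train, y_train)

# score() returns the mean accuracy on the given samples.
train_acc = clf.score(X_train, y_train)
test_acc = clf.score(X_test, y_test)
print(f"train accuracy: {train_acc:.3f}")
print(f"test accuracy:  {test_acc:.3f}")
```

A large gap between the two numbers would suggest the tree is overfitting the training samples; lowering `max_depth` is one way to check that.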

    2. Random forest

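    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap
    from sklearn import datasets
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import train_test_split

    # Same data as before: the first two features of the wine dataset.
    wine = datasets.load_wine()
    X = wine.data[:, :2]
    y = wine.target
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    # Fit a random forest made of 6 trees.
    forest = RandomForestClassifier(n_estimators=6, random_state=3)
    forest.fit(X_train, y_train)

    cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
    cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])

    # Build a grid over the feature space (step 0.02) and predict every grid point.
    x_min, x_max = X_train[:, 0].min() - 1, X_train[:, 0].max() + 1
    y_min, y_max = X_train[:, 1].min() - 1, X_train[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, .02),
                         np.arange(y_min, y_max, .02))
    Z = forest.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)

    # Draw the predicted regions, then overlay all samples colored by true class.
    plt.figure()
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light)
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, edgecolors='k', s=20)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.title("Classifier: RandomForest (n_estimators = 6)")
    plt.show()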
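
To get a feel for what the forest adds over a single tree, one option is to compare the ensemble's test accuracy with that of each of its 6 member trees. This is only a sketch: `random_state=0` in the split is an assumption added for repeatability, and comparing `est.predict(...)` directly with `y_test` relies on the wine labels already being 0, 1, 2, since the member trees predict class indices.

```python
import numpy as np
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

wine = datasets.load_wine()
X = wine.data[:, :2]
y = wine.target

# random_state=0 is an assumption for reproducibility, not part of the original code.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

forest = RandomForestClassifier(n_estimators=6, random_state=3)
forest.fit(X_train, y_train)

# Accuracy of the whole ensemble on the held-out samples.
forest_acc = forest.score(X_test, y_test)

# Accuracy of each member tree; estimators_ holds the fitted trees.
# The direct comparison works here because the wine labels are 0, 1, 2.
tree_accs = [float(np.mean(est.predict(X_test) == y_test))
             for est in forest.estimators_]

print(f"forest accuracy: {forest_acc:.3f}")
for i, acc in enumerate(tree_accs):
    print(f"tree {i} accuracy: {acc:.3f}")
```

The ensemble is not guaranteed to beat every individual tree on a given split, especially with only 6 trees and 2 features, so the printed numbers are best read as an illustration rather than a benchmark.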

     

  • Original article: https://www.cnblogs.com/zhaop8078/p/9744358.html