zoukankan      html  css  js  c++  java
  • 使用 ID3 对 Titanic 进行决策树分类

    原创转载请注明出处:https://www.cnblogs.com/agilestyle/p/12722688.html

    过程划分

    数据加载

    import graphviz
    import numpy as np
    import pandas as pd
    from sklearn import tree
    from sklearn.feature_extraction import DictVectorizer
    from sklearn.model_selection import cross_val_score
    from sklearn.tree import DecisionTreeClassifier
    
    # 数据加载
    train_data = pd.read_csv(r'/data/Titanic/train.csv')
    test_data = pd.read_csv(r'/data/Titanic/train.csv')

    数据探索

    # 数据探索
    print('-' * 30)
    print(train_data.info())
    print('-' * 30)
    print(train_data.describe())
    print('-' * 30)
    print(train_data.describe(include=['O']))
    print('-' * 30)
    print(train_data.head())
    print('-' * 30)
    print(train_data.tail())

    Console Output

    ------------------------------
    <class 'pandas.core.frame.DataFrame'>
    RangeIndex: 891 entries, 0 to 890
    Data columns (total 12 columns):
    PassengerId    891 non-null int64
    Survived       891 non-null int64
    Pclass         891 non-null int64
    Name           891 non-null object
    Sex            891 non-null object
    Age            714 non-null float64
    SibSp          891 non-null int64
    Parch          891 non-null int64
    Ticket         891 non-null object
    Fare           891 non-null float64
    Cabin          204 non-null object
    Embarked       889 non-null object
    dtypes: float64(2), int64(5), object(5)
    memory usage: 83.7+ KB
    None
    ------------------------------
           PassengerId    Survived      Pclass         Age       SibSp  
    count   891.000000  891.000000  891.000000  714.000000  891.000000   
    mean    446.000000    0.383838    2.308642   29.699118    0.523008   
    std     257.353842    0.486592    0.836071   14.526497    1.102743   
    min       1.000000    0.000000    1.000000    0.420000    0.000000   
    25%     223.500000    0.000000    2.000000   20.125000    0.000000   
    50%     446.000000    0.000000    3.000000   28.000000    0.000000   
    75%     668.500000    1.000000    3.000000   38.000000    1.000000   
    max     891.000000    1.000000    3.000000   80.000000    8.000000   
    
                Parch        Fare  
    count  891.000000  891.000000  
    mean     0.381594   32.204208  
    std      0.806057   49.693429  
    min      0.000000    0.000000  
    25%      0.000000    7.910400  
    50%      0.000000   14.454200  
    75%      0.000000   31.000000  
    max      6.000000  512.329200  
    ------------------------------
                                                       Name   Sex  Ticket Cabin  
    count                                               891   891     891   204   
    unique                                              891     2     681   147   
    top     Lobb, Mrs. William Arthur (Cordelia K Stanlick)  male  347082    G6   
    freq                                                  1   577       7     4   
    
           Embarked  
    count       889  
    unique        3  
    top           S  
    freq        644  
    ------------------------------
       PassengerId  Survived  Pclass  
    0            1         0       3   
    1            2         1       1   
    2            3         1       3   
    3            4         1       1   
    4            5         0       3   
    
                                                    Name     Sex   Age  SibSp  
    0                            Braund, Mr. Owen Harris    male  22.0      1   
    1  Cumings, Mrs. John Bradley (Florence Briggs Th...  female  38.0      1   
    2                             Heikkinen, Miss. Laina  female  26.0      0   
    3       Futrelle, Mrs. Jacques Heath (Lily May Peel)  female  35.0      1   
    4                           Allen, Mr. William Henry    male  35.0      0   
    
       Parch            Ticket     Fare Cabin Embarked  
    0      0         A/5 21171   7.2500   NaN        S  
    1      0          PC 17599  71.2833   C85        C  
    2      0  STON/O2. 3101282   7.9250   NaN        S  
    3      0            113803  53.1000  C123        S  
    4      0            373450   8.0500   NaN        S  
    ------------------------------
         PassengerId  Survived  Pclass                                      Name  
    886          887         0       2                     Montvila, Rev. Juozas   
    887          888         1       1              Graham, Miss. Margaret Edith   
    888          889         0       3  Johnston, Miss. Catherine Helen "Carrie"   
    889          890         1       1                     Behr, Mr. Karl Howell   
    890          891         0       3                       Dooley, Mr. Patrick   
    
            Sex   Age  SibSp  Parch      Ticket   Fare Cabin Embarked  
    886    male  27.0      0      0      211536  13.00   NaN        S  
    887  female  19.0      0      0      112053  30.00   B42        S  
    888  female   NaN      1      2  W./C. 6607  23.45   NaN        S  
    889    male  26.0      0      0      111369  30.00  C148        C  
    890    male  32.0      0      0      370376   7.75   NaN        Q  

    数据清洗

    # 数据清洗
    # 使用平均年龄来填充年龄中的 nan 值
    train_data['Age'].fillna(train_data['Age'].mean(), inplace=True)
    test_data['Age'].fillna(test_data['Age'].mean(), inplace=True)
    # 使用票价的均值填充票价中的 nan 值
    train_data['Fare'].fillna(train_data['Fare'].mean(), inplace=True)
    test_data['Fare'].fillna(test_data['Fare'].mean(), inplace=True)
    # 使用登录最多的港口来填充登录港口的 nan 值
    print(train_data['Embarked'].value_counts())
    train_data['Embarked'].fillna('S', inplace=True)
    test_data['Embarked'].fillna('S', inplace=True)

    特征选择

    特征选择是分类器的关键。特征选择不同,得到的分类器也不同。可以通过数据探索发现来选择哪些特征做生存的预测。PassengerId 为乘客编号,对分类没有作用,可以放弃;Name 为乘客姓名,对分类没有作用,可以放弃;Cabin 字段缺失值太多,可以放弃;Ticket 字段为船票号码,杂乱无章且无规律,可以放弃。其余的字段包括:Pclass、Sex、Age、SibSp、Parch 和 Fare,这些属性分别表示了乘客的船票等级、性别、年龄、亲戚数量以及船票价格,可能会和乘客的生存预测分类有关系。具体是什么关系,可以交给分类器来处理。

    # 特征选择
    features = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']
    train_features = train_data[features]
    train_labels = train_data['Survived']
    test_features = test_data[features]
    dvec = DictVectorizer(sparse=False)
    # fit_transform 函数将特征向量转化为特征值矩阵
    train_features = dvec.fit_transform(train_features.to_dict(orient='record'))
    print(dvec.feature_names_)

    Console Output

    ['Age', 'Embarked=C', 'Embarked=Q', 'Embarked=S', 'Fare', 'Parch', 'Pclass', 'Sex=female', 'Sex=male', 'SibSp']
    (891, 10)

    可以看到原本是一列的 Embarked,变成了“Embarked=C”、“Embarked=Q”、“Embarked=S”三列。Sex 列变成了“Sex=female”、“Sex=male”两列。

    这样 train_features 特征矩阵就包括 10 个特征值(列),以及 891 个样本(行),即 891 行,10 列的特征矩阵。

    Note: fit_transform 和 transform 的区别

    • fit 从一个训练集中学习模型参数,其中就包括了归一化时用到的均值,标准偏差等,可以理解为一个训练过程。
    • transform: 在fit的基础上,对数据进行标准化,降维,归一化等数据转换操作。
    • fit_transform: 将模型训练和转化合并到一起,训练样本先做fit,得到mean,standard deviation,然后将这些参数用于transform(归一化训练数据),使得到的训练数据是归一化的,而测试数据只需要在原先fit得到的mean,std上来做归一化就行了,所以用transform就行了。

    需要注意的是,transform和fit_transform虽然结果相同,但是不能互换。因为fit_transform只是 fit+transform两个步骤合并的简写。而各种分类算法都需要先fit,然后再进行transform。所以如果把fit_transform替换为transform可能会报错。

    建模训练

    # 决策树模型
    # 构造 ID3 决策树
    clf = DecisionTreeClassifier(criterion='entropy')
    
    # 决策树训练
    clf.fit(train_features, train_labels)
    
    # 模型预测评估
    test_features=dvec.transform(test_features.to_dict(orient='record'))
    
    # 决策树预测
    pred_labels = clf.predict(test_features)
    
    # 决策树准确率
    from sklearn.metrics import accuracy_score
    
    train_score = clf.score(train_features, train_labels)
    test_labels = test_data['Survived']
    test_score = accuracy_score(test_labels, pred_labels)
    
    print(u'train_score 准确率为 %.4lf' % train_score)
    print(u'test_score 准确率为 %.4lf' % test_score)

    Console Output

    train_score 准确率为 0.9820
    test_score 准确率为 0.9820

    Note: 

    用训练集做训练,再用训练集自身做准确率评估自然会很高。但这样得出的准确率并不能代表决策树分类器的准确率。因为没有测试集的实际结果,因此无法用测试集的预测结果与实际结果做对比。如果使用 score 函数对训练集的准确率进行统计,正确率会接近于 100%(如上结果为 98.2%),无法对分类器的在实际环境下做准确率的评估。模型准确率需要考虑是否有测试集的实际结果可以做对比,当测试集没有真实结果可以对比时,需要使用 K 折交叉验证 cross_val_score。

    K 折交叉验证

    交叉验证是一种常用的验证分类准确率的方法,原理是拿出大部分样本进行训练,少量的用于分类器的验证。K 折交叉验证,就是做 K 次交叉验证,每次选取 K 分之一的数据作为验证,其余作为训练。轮流 K 次,取平均值。

    K 折交叉验证的原理

    1. 将数据集平均分割成 K 个等份;
    2. 使用 1 份数据作为测试数据,其余作为训练数据;
    3. 计算测试准确率;
    4. 使用不同的测试集,重复 2、3 步骤。

    在 sklearn 的 model_selection 模型选择中提供了 cross_val_score 函数。cross_val_score 函数中的参数 cv 代表对原始数据划分成多少份,也就是 K 值,一般建议 K 值取 10,因此可以设置 CV=10,可以对比下 score 和 cross_val_score 两种函数的正确率的评估结果。

    # K 折交叉验证统计决策树准确率
    cv_score = np.mean(cross_val_score(clf, train_features, train_labels, cv=10))
    print(u'cross_val_score 准确率为 %.4lf' % cv_score)

    Console Output (每次运行结果可能会有不同)

    cross_val_score 准确率为 0.7746

    决策树可视化

    dot_data = tree.export_graphviz(clf, out_file=None)
    graph = graphviz.Source(dot_data)
    graph.view('titanic')

    Note:如果提示 graphviz 不可用或者引用不到等错误,运行如下命令进行安装

    conda install graphviz
    conda install python-graphviz
    conda install pydot

    执行后,可以得到下面的图示

    Reference

    https://github.com/cystanford/Titanic_Data

    https://time.geekbang.org/column/article/79072

    https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.describe.html

    https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.fillna.html

    https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_dict.html

    https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.DictVectorizer.html

    https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_score.html

  • 相关阅读:
    c:forTokens标签循环输出
    jsp转long类型为date,并且格式化
    spring中@Param和mybatis中@Param使用区别(暂时还没接触)
    734. Sentence Similarity 有字典数组的相似句子
    246. Strobogrammatic Number 上下对称的数字
    720. Longest Word in Dictionary 能连续拼接出来的最长单词
    599. Minimum Index Sum of Two Lists两个餐厅列表的索引和最小
    594. Longest Harmonious Subsequence强制差距为1的最长连续
    645. Set Mismatch挑出不匹配的元素和应该真正存在的元素
    409. Longest Palindrome 最长对称串
  • 原文地址:https://www.cnblogs.com/agilestyle/p/12722688.html
Copyright © 2011-2022 走看看