zoukankan      html  css  js  c++  java
  • 机器学习之路: python 决策树分类DecisionTreeClassifier 预测泰坦尼克号乘客是否幸存

    使用python3 学习了决策树分类器的api

    涉及到 特征的提取,数据类型保留,分类类型抽取出来新的类型

    需要网上下载数据集,我把他们下载到了本地,

    可以到我的git下载代码和数据集: https://github.com/linyi0604/MachineLearning

     1 import pandas as pd
     2 from sklearn.cross_validation import train_test_split
     3 from sklearn.feature_extraction import DictVectorizer
     4 from sklearn.tree import DecisionTreeClassifier
     5 from sklearn.metrics import classification_report
     6 
     7 '''
     8 决策树
     9 涉及多个特征,没有明显的线性关系
    10 推断逻辑非常直观
    11 不需要对数据进行标准化
    12 '''
    13 
    14 '''
    15 1 准备数据
    16 '''
    17 # 读取泰坦尼克乘客数据,已经从互联网下载到本地
    18 titanic = pd.read_csv("./data/titanic/titanic.txt")
    19 # 观察数据发现有缺失现象
    20 # print(titanic.head())
    21 
    22 # 提取关键特征,sex, age, pclass都很有可能影响是否幸免
    23 x = titanic[['pclass', 'age', 'sex']]
    24 y = titanic['survived']
    25 # 查看当前选择的特征
    26 # print(x.info())
    27 '''
    28 <class 'pandas.core.frame.DataFrame'>
    29 RangeIndex: 1313 entries, 0 to 1312
    30 Data columns (total 3 columns):
    31 pclass    1313 non-null object
    32 age       633 non-null float64
    33 sex       1313 non-null object
    34 dtypes: float64(1), object(2)
    35 memory usage: 30.9+ KB
    36 None
    37 '''
    38 # age数据列 只有633个,对于空缺的 采用平均数或者中位数进行补充 希望对模型影响小
    39 x['age'].fillna(x['age'].mean(), inplace=True)
    40 
    41 '''
    42 2 数据分割
    43 '''
    44 x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.25, random_state=33)
    45 # 使用特征转换器进行特征抽取
    46 vec = DictVectorizer()
    47 # 类别型的数据会抽离出来 数据型的会保持不变
    48 x_train = vec.fit_transform(x_train.to_dict(orient="record"))
    49 # print(vec.feature_names_)   # ['age', 'pclass=1st', 'pclass=2nd', 'pclass=3rd', 'sex=female', 'sex=male']
    50 x_test = vec.transform(x_test.to_dict(orient="record"))
    51 
    52 '''
    53 3 训练模型 进行预测
    54 '''
    55 # 初始化决策树分类器
    56 dtc = DecisionTreeClassifier()
    57 # 训练
    58 dtc.fit(x_train, y_train)
    59 # 预测 保存结果
    60 y_predict = dtc.predict(x_test)
    61 
    62 '''
    63 4 模型评估
    64 '''
    65 print("准确度:", dtc.score(x_test, y_test))
    66 print("其他指标:
    ", classification_report(y_predict, y_test, target_names=['died', 'survived']))
    67 '''
    68 准确度: 0.7811550151975684
    69 其他指标:
    70               precision    recall  f1-score   support
    71 
    72        died       0.91      0.78      0.84       236
    73    survived       0.58      0.80      0.67        93
    74 
    75 avg / total       0.81      0.78      0.79       329
    76 '''
  • 相关阅读:
    【Linux开发】Linux下jpeglib库的安装详解
    【Linux开发】Linux下jpeglib库的安装详解
    【Linux开发】jpeglib使用指南
    【Linux开发】jpeglib使用指南
    【Linux开发】为qt-embedded添加jpeg库的交叉编译方法for arm
    【Linux开发】为qt-embedded添加jpeg库的交叉编译方法for arm
    Windows 7 64bit上安装Oracle Database 12c [INS-30131] 错误的解决方法
    Log4j 日志记录
    如何根据Ip获取地址信息--Java----待整理完善!!!
    Struts如何获取客户端ip地址
  • 原文地址:https://www.cnblogs.com/Lin-Yi/p/8970609.html
Copyright © 2011-2022 走看看