zoukankan      html  css  js  c++  java
  • 第三章 模型搭建和评估-评估

    import pandas as pd
    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    from IPython.display import Image
    from sklearn.linear_model import LogisticRegression
    from sklearn.ensemble import RandomForestClassifier
    %matplotlib inline
    plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
    plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
    plt.rcParams['figure.figsize'] = (10, 6)  # 设置输出图片大小
    

    任务:加载数据并分割测试集和训练集

    from sklearn.model_selection import train_test_split
    
    # 一般先取出X和y后再切割,有些情况会使用到未切割的,这时候X和y就可以用,x是清洗好的数据,y是我们要预测的存活数据'Survived'
    data = pd.read_csv('clear_data.csv')
    train = pd.read_csv('train.csv')
    X = data
    y = train['Survived']
    
    # 对数据集进行切割
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
    
    # 默认参数逻辑回归模型
    lr = LogisticRegression()
    lr.fit(X_train, y_train)
    

    请添加图片描述

    模型评估

    • 模型评估是为了知道模型的泛化能力。
    • 交叉验证(cross-validation)是一种评估泛化性能的统计学方法,它比单次划分训练集和测试集的方法更加稳定、全面。
    • 在交叉验证中,数据被多次划分,并且需要训练多个模型。
    • 最常用的交叉验证是 k 折交叉验证(k-fold cross-validation),其中 k 是由用户指定的数字,通常取 5 或 10。
    • 准确率(precision)度量的是被预测为正例的样本中有多少是真正的正例
    • 召回率(recall)度量的是正类样本中有多少被预测为正类
    • f-分数是准确率与召回率的调和平均
      任务一:交叉验证
    • 用10折交叉验证来评估之前的逻辑回归模型
    • 计算交叉验证精度的平均值
    • 请添加图片描述
      提示4
      交叉验证在sklearn中的模块为sklearn.model_selection
    from sklearn.model_selection import cross_val_score
    lr = LogisticRegression(C=100)
    scores = cross_val_score(lr, X_train, y_train, cv=10)
    

    请添加图片描述

    # k折交叉验证分数
    scores
    

    请添加图片描述

    # 平均交叉验证分数
    print("Average cross-validation score: {:.2f}".format(scores.mean()))
    

    请添加图片描述
    思考4
    k折越多的情况下会带来什么样的影响?
    #思考回答
    一般而言,k折越多,评估结果的稳定性和保真性越高,不过整个计算复杂度越高。一种特殊的情况是k=m,m为数据集样本个数,这种特例称为留一法,结果往往比较准确

    任务二:混淆矩阵

    • 计算二分类问题的混淆矩阵
    • 计算精确率、召回率以及f-分数
      【思考】什么是二分类问题的混淆矩阵,理解这个概念,知道它主要是运算到什么任务中的
      #思考回答

    混淆矩阵含义 以及一些参数的计算
    请添加图片描述

    请添加图片描述
    提示5

    • 混淆矩阵的方法在sklearn中的sklearn.metrics模块
    • 混淆矩阵需要输入真实标签和预测标签
    • 精确率、召回率以及f-分数可使用classification_report模块
    from sklearn.metrics import confusion_matrix
    # 训练模型
    lr = LogisticRegression(C=100)
    lr.fit(X_train, y_train)
    

    请添加图片描述

    pred = lr.predict(X_train)
    confusion_matrix(y_train, pred)
    cm = metrics.confusion_matrix(y_train,pred,labels = [0,1])
    from sklearn.metrics import classification_report
    # 精确率、召回率以及f1-score
    print(classification_report(y_train, pred))
    

    请添加图片描述

    sns.heatmap(confusion_matrix(y_train, pred),annot = True,fm2 = '.2e',cmap = 'GnBu')
    sns.heatmap(cm,annot=True, fmt = '.2e',cmap = 'GnBu')
    plt.show()
    
    

    请添加图片描述
    【思考】

    如果自己实现混淆矩阵的时候该注意什么问题

    #回答 https://blog.csdn.net/u011587322/article/details/80660978

    任务三:ROC曲线

    绘制ROC曲线
    【思考】什么是ROC曲线,OCR曲线的存在是为了解决什么问题?

    #思考
    在分类模型中,ROC曲线和AUC值经常作为衡量一个模型拟合程度的指标。最近在建模过程中需要作出模型的ROC曲线


    提示6

    • ROC曲线在sklearn中的模块为sklearn.metrics
    • ROC曲线下面所包围的面积越大越好
    #写入代码
    
    from sklearn.metrics import roc_curve, auc 
    fpr,tpr,threholds = roc_curve(y_test,lr.decision_function(X_test))
    plt.plot(fpr,tpr,label = "ROC Curve")
    plt.xlabel("FPR")
    plt.ylabel("TPR(recall)")
    # 找到最接近于0的阈值 ???
    close_zero = np.argmin(np.abs(threholds))  
    plt.plot(fpr[close_zero],tpr[close_zero], 'o', markersize = 7, label = "threholds zero",fillstyle="none",c='k',mew=2)
    plt.legend(loc=4)
    

    请添加图片描述

    问题总结:搞清楚这些问题

    1. 分层抽样,这样的好处?
    2. 什么情况下切割数据集的时候不用进行随机选取
    3. 为什么线性模型可以进行分类任务,背后是怎么的数学关系
    4. 对于多分类问题,线性模型是怎么进行分类的
    5. 预测标签的概率对我们有什么帮助
    6. k折越多的情况下会带来什么样的影响?
    7. 如果自己实现混淆矩阵的时候该注意什么问题
    8. 对于多分类问题如何绘制ROC曲线
  • 相关阅读:
    learning scala view collection
    scala
    learning scala dependency injection
    learning scala implicit class
    learning scala type alise
    learning scala PartialFunction
    learning scala Function Recursive Tail Call
    learning scala Function Composition andThen
    System.Threading.Interlocked.CompareChange使用
    System.Threading.Monitor的使用
  • 原文地址:https://www.cnblogs.com/most-silence/p/15495345.html
Copyright © 2011-2022 走看看