zoukankan      html  css  js  c++  java
  • 【笔记】精准率和召回率的平衡以及精准率召回率曲线

    精准率和召回率的平衡以及精准率召回率曲线

    精准率和召回率的平衡

    前文书道,使用F1 score可以同时考虑到精准率和召回率,那么,怎么能同时让二者都变大呢,可惜,是不行的,因为这两个指标是互相矛盾的,一方提高,一方就会下降,因此要找到两者的平衡

    可以通过逻辑回归决策边界来理解

    对于决策边界来说,在逻辑回归中是设置为0的线,这里可以设置为一个常量threshold,此时就可以将方程写成

    大于threshold分为1,小于threshold分为0,这样也可以形成决策边界,这样就引入了一个新的超参数,相当于将原先的决策边界进行了平移操作,可以通过调整分类域值来对精准率和召回率进行调整,使其增大减小

    大概可以作为这样

    具体实现

    (在notebook中)

    熟悉的环境部署,前面部分都在这里有

    使用逻辑回归算法并用训练出的数据进行预测

      from sklearn.linear_model import LogisticRegression
    
      log_reg = LogisticRegression()
      log_reg.fit(X_train,y_train)
      y_predict = log_reg.predict(X_test)
    

    使用sklearn中的F1 score来计算并求出准确率

      from sklearn.metrics import f1_score
    
      f1_score(y_test,y_predict)
    

    结果如下

    得出混淆矩阵

      from sklearn.metrics import confusion_matrix
    
      confusion_matrix(y_test,y_predict)
    

    结果如下

    得出精准率

      from sklearn.metrics import precision_score
    
      precision_score(y_test,y_predict)
    

    结果如下

    得出召回率

      from sklearn.metrics import recall_score
    
      recall_score(y_test,y_predict)
    

    结果如下

    如果想调整分类域值的话,可以使用decision_function,进行决策

    观察前十个样本对应的逻辑算法中的score值

      log_reg.decision_function(X_test)[:10]
    

    结果如下

    取出前十个预测结果

      log_reg.predict(X_test)[:10]
    

    结果如下(由于默认为0分类,所以都为0)

    设decision_scores来保存分数值

      decision_scores = log_reg.decision_function(X_test)
    

    其中的最小值

      np.min(decision_scores)
    

    结果如下

    相应的最大值

      np.max(decision_scores)
    

    结果如下

    设置新的threshold为5,大于等于5的分类为1,小于的分为0

      y_predict_2 = np.array(decision_scores >= 5 , dtype='int')
    

    此时的混淆矩阵

    confusion_matrix(y_test,y_predict_2)

    结果如下

    此时的精准率

      precision_score(y_test,y_predict_2)
    

    结果如下

    此时的召回率

      recall_score(y_test,y_predict_2)
    

    结果如下

    设置新的threshold为-5,大于等于-5的分类为1,小于的分为0

      y_predict_3 = np.array(decision_scores >= -5 , dtype='int')
    

    此时的混淆矩阵

      confusion_matrix(y_test,y_predict_3)
    

    结果如下

    此时的精准率

      precision_score(y_test,y_predict_3)
    

    结果如下

    此时的召回率

      recall_score(y_test,y_predict_3)
    

    结果如下

    通过修改域值来将精准率和召回率增大减小,这就说明了这两个值是相互影响的,那么使用可视化的方法来让这个更加的直观的展现出来

    精准率召回率曲线

    具体实现

    (在notebook中)

    求出decision_scores中的最大值以及最小值,设置一个数组,最小值和最大值分别为数组的起点和终点,设置步长为0.1,将其称为thresholds,然后就要求出对应的每一个值的精准率和召回率,绘制出对应的位置,设置两个列表用来存放精准率和召回率

    设置一个循环,对于每一个threshold,都进行一次新的预测,方式就是让decision_scores大于等于这一次循环中的threshold,将其设置为int型,得到预测值以后,对精准率和召回率进行求解,将其放入两个列表中,这样就得到了一个随着threshold不断变化,精准率和召回率也不断变化的一个操作,使用这个可以很容易的绘制出两根曲线

      from sklearn.metrics import precision_score
      from sklearn.metrics import recall_score
    
      precisions = []
      recalls = []
      thresholds = np.arange(np.min(decision_scores),np.max(decision_scores),0.1)
      for threshold in thresholds:
          y_predict = np.array(decision_scores >= threshold,dtype='int')
          precisions.append(precision_score(y_test,y_predict))
          recalls.append(recall_score(y_test,y_predict))
    
      plt.plot(thresholds,precisions)
      plt.plot(thresholds,recalls)
    

    图像如下

    绘制出精准率和召回率的线相关曲线

      plt.plot(precisions,recalls)
    

    图像如下(在突然下降的点前很有可能就是两者最好的平衡点)

    在sklearn中调用精准率和召回率的曲线,在metrics中使用precision_recall_curve这个类即可,传入参数就是真值以及分类用的scores值的数组,返回三个值

      from sklearn.metrics import precision_recall_curve
    
      precisions, recalls, thresholds =  precision_recall_curve(y_test,decision_scores)
    

    这个算法会根据传入的值自动的选取最合适的步长
    这三个值的最好的结果如下

    绘制曲线,由于精准率和召回率的值比thresholds的值多一个,所以要去掉一个值

      plt.plot(thresholds,precisions[:-1])
      plt.plot(thresholds,recalls[:-1])
    

    图像如下

    通过这个图还是可以看到这两个值互相制约的情况,而且通过这个图就可以找到需要的部分的精准率和召回率的范围选择相应的域值

    对于Precision-recall曲线来说,整体的趋势

    同时,每一次通过不同的参数得到的曲线(比原来好的),整体有一个模型的曲线更靠外,通常就更好,有的时候,与坐标轴构成的面积越大越好的方式更容易理解

  • 相关阅读:
    AJAX
    大前端面试一(基础)
    webpack打包vue -->简易讲解
    vue实现原理
    Spring boot 线上部署
    javascript 事件
    React native采坑路 Running 1 of 1 custom shell scripts
    PHP swoole实现redis订阅和发布
    JAVA 注解和反射
    微信公众平台获取用户地理位置之开发文档详解
  • 原文地址:https://www.cnblogs.com/jokingremarks/p/14325277.html
Copyright © 2011-2022 走看看