zoukankan      html  css  js  c++  java
  • 计算auc-python/awk

    1.自己写的计算auc的代码,用scikit-learn的auc计算函数sklearn.metrics.auc(xyreorder=False)做了一些测试,结果是一样的,如有错误,欢迎指正。

    思路:1.首先对预测值进行排序,排序的方式用了python自带的函数sorted,详见注释。

       2.对所有样本按照预测值从小到大标记rank,rank其实就是index+1,index是排序后的sorted_pred数组中的索引

       3.将所有正样本的rank相加,遇到预测值相等的情况,不管样本的正负性,对rank要取平均值再相加

            4.将rank相加的和减去正样本排在正样本之后的情况,再除以总的组合数,得到auc

    
    
     1 # -*- coding: utf-8 -*-
     2 """
     3 Created on Wed May  3 10:48:28 2017
     4 
     5 @author: Vincent
     6 """
     7 import numpy as np
     8 from sklearn import metrics
     9 y = np.array(   [1,     0,  0,   1,   1,  1,  0,  1,  1,  1])
    10 pred = np.array([0.9, 0.9,0.8, 0.8, 0.7,0.7,0.7,0.6,0.5,0.4])
    11 fpr, tpr, thresholds = metrics.roc_curve(y, pred, pos_label=1)
    12 print(metrics.auc(fpr, tpr))
    13 def getAuc(labels, pred) :
    14     '''将pred数组的索引值按照pred[i]的大小正序排序,返回的sorted_pred是一个新的数组,
    15        sorted_pred[0]就是pred[i]中值最小的i的值,对于这个例子,sorted_pred[0]=8
    16     '''
    17     sorted_pred = sorted(range(len(pred)), key = lambda i : pred[i])
    18     pos = 0.0 #正样本个数
    19     neg = 0.0 #负样本个数
    20     auc = 0.0 
    21     last_pre = pred[sorted_pred[0]]
    22     count = 0.0
    23     pre_sum = 0.0  #当前位置之前的预测值相等的rank之和,rank是从1开始的,所以在下面的代码中就是i+1
    24     pos_count = 0.0  #记录预测值相等的样本中标签是正的样本的个数
    25     for i in range(len(sorted_pred)) :
    26         if labels[sorted_pred[i]] > 0:
    27             pos += 1        
    28         else:
    29             neg += 1       
    30         if last_pre != pred[sorted_pred[i]]: #当前的预测概率值与前一个值不相同
    31             #对于预测值相等的样本rank需要取平均值,并且对rank求和
    32             auc += pos_count * pre_sum / count  
    33             count = 1          
    34             pre_sum = i + 1     #更新为当前的rank    
    35             last_pre = pred[sorted_pred[i]] 
    36             if labels[sorted_pred[i]] > 0:
    37                 pos_count = 1   #如果当前样本是正样本 ,则置为1
    38             else:
    39                 pos_count = 0   #反之置为0
    40         else:
    41             pre_sum += i + 1    #记录rank的和
    42             count += 1          #记录rank和对应的样本数,pre_sum / count就是平均值了
    43             if labels[sorted_pred[i]] > 0:#如果是正样本
    44                 pos_count += 1  #正样本数加1
    45     auc += pos_count * pre_sum / count #加上最后一个预测值相同的样本组
    46     auc -= pos *(pos + 1) / 2 #减去正样本在正样本之前的情况
    47     auc = auc / (pos * neg)  #除以总的组合数
    48     return auc
    49 print(getAuc(y, pred))

     2.awk代码

     1 #计算auc,输入分别为预测值(可以乘以一个倍数之后转化为整数),该相同预测值的样本个数,该相同预测值的正样本个数
     2 sort -t $'	' -k 1,1n | awk -F"	" 'BEGIN{
     3     OFS="	";
     4     now_q="";
     5     begin_rank=1;
     6     now_pos_num=0;
     7     now_neg_num=0;
     8     total_pos_rank=0;
     9     total_pos_num=0;
    10     total_neg_num=0;
    11 }function clear(){
    12     begin_rank += now_pos_num + now_neg_num;
    13     now_pos_num=0;
    14     now_neg_num=0;
    15 }function update(){
    16     now_pos_num += pos_num;
    17     now_neg_num += neg_num;
    18 }function output(){
    19     n = now_pos_num + now_neg_num;
    20     avg_rank = begin_rank + (n-1)/2;
    21     tmp_all_pos_rank = avg_rank * now_pos_num;
    22     total_pos_rank += tmp_all_pos_rank;
    23     total_pos_num += now_pos_num;
    24     total_neg_num += now_neg_num;
    25 }{
    26     q=$1;
    27     show=$2;
    28     clk=$3;
    29     pos_num=clk;
    30     neg_num=show-clk;
    31     if(now_q!=q){
    32         if(now_q!=""){
    33             output();
    34             clear();
    35         }
    36         now_q=q;
    37     }
    38     update();
    39 
    40 }END{
    41     output();
    42     auc=0;
    43     m=total_pos_num;
    44        n=total_neg_num;
    45     if(m>0 && n>0){
    46         auc = (total_pos_rank-m*(m+1)/2) / (m*n);
    47     }
    48     print auc;
    49 }'
  • 相关阅读:
    Ubuntu 安装 JDK 7 / JDK8 的两种方式
    python 深拷贝 浅拷贝 赋值
    importlib.import_module
    pandas分块读取大量数据集
    win10下安装XGBoost Gpu版本
    win10下安装LGBM GPU版本
    统计自然语言处理(第二版)笔记1
    K-近邻算法
    2019考研的一些心得
    lib和dll的区别与使用
  • 原文地址:https://www.cnblogs.com/fisherinbox/p/6806164.html
Copyright © 2011-2022 走看看