zoukankan      html  css  js  c++  java
  • 评估一个预测模型性能通常都有那些指标

    对于不同类型的模型,会有不同的评估指标,那么我们从最直接的回归和分类这两个类型,对于结果连续的回归问题,

    一般使用的大致为:MSE(均方差),MAE(绝对平均差),RMSE(根均方差)这三种评估方法,这三种方式公式此处补贴出来。

    对于离散的分类问题,我们一般看ROC曲线,以及AUC曲线,一般好的模型,ROC曲线,在一开始就直接上升到1,然后一直保持1,也就是使得AUC=1.0或者尽可能的让其

    接近这个值,这是我们奋斗的目标.

    摘个实际的例子:--出自《预测分析核心算法》这本书.

     1 #-*-coding:utf-8-*-
     2 __author__ ='gxjun'
     3 import pandas as pd
     4 import matplotlib.pyplot as plt
     5 from pandas import DataFrame
     6 from random import uniform
     7 import math 
     8 import numpy as np
     9 import random 
    10 import pylab as pl
    11 from sklearn import datasets,linear_model
    12 from sklearn.metrics import roc_curve ,auc
    13 
    14 
    15 ##计算RP值
    16 def confusionMatrix(predicted ,actual , threshold):
    17     if len(predicted) != len(actual):
    18         return -1;
    19     tp=0.0;
    20     fp=0.0;
    21     tn=0.0;
    22     fn=0.0;
    23     for i in range(len(actual)):
    24         if actual[i] >0.5:
    25             if predicted[i] > threshold:
    26                 tp+=1.0;
    27             else:
    28                 fn+=1.0;
    29         else:
    30             if predicted[i]<threshold:
    31                 tn+=1.0;
    32             else:
    33                 fp+=1.0;
    34     rtn=[tp,tn,fp,fn];
    35     return rtn;
    36 target_url =("https://archive.ics.uci.edu/ml/machine-learning-databases/undocumented/connectionist-bench/sonar/sonar.all-data")
    37 data = pd.read_csv(target_url,header=None,prefix='V');
    38 print('-'*80)
    39 print(data.head())
    40 print('-'*80)
    41 print(data.tail())
    42 print('-'*80)
    43 print(data.describe())
    44 print('-'*80)
    45 label = [];
    46 dataRows = [];
    47 
    48 for i in range(208):
    49     if data.iat[i,-1]=='M':
    50         label.append(1.0);
    51     else:
    52         label.append(0);
    53 print label        
    54 dataRows=data.iloc[:,0:-1];
    55 x_train = np.array(dataRows);
    56 y_train = np.array(label);
    57 print "x_train shape: {} , y_train shape: {}".format(x_train.shape,y_train.shape);
    58 print "x_test shape: {} , y_test shape: {}".format(x_test.shape,y_test.shape);
    59 x_test = np.array(dataRows[0:int(208/3)]);
    60 y_test = np.array(label[0:int(208/3)]);
    61 #train model
    62 rockModel = linear_model.LinearRegression();
    63 rockModel.fit(x_train,y_train);
    64 prob = rockModel.predict(x_train);
    65 print('-'*80);
    66 confusionMatrain = confusionMatrix(prob,y_train,threshold=0.5);
    67 
    68 #print confusionMatrain
    69 fpr ,tpr,threshold = roc_curve(y_train,prob);
    70 roc_auc = auc(fpr,tpr);
    71 
    72 plt.clf();
    73 plt.plot(fpr,tpr,label='ROC curve(area =%0.2f)'%roc_auc);
    74 pl.plot([0,1],[0,1],'k-');
    75 pl.xlim([0.0,1.0]);
    76 pl.ylim([0.0,1.0]);
    77 pl.xlabel("FP rate}");
    78 pl.ylabel("TP rate}");
    79 pl.title("ROC");
    80 pl.legend(loc="lower right");
    81 pl.show()

    结果为:

  • 相关阅读:
    java mail使用qq邮箱发邮件的配置方法
    (利用tempdata判断action是直接被访问还是重定向访问)防止微信活动中用户绕过关注公众号的环节
    判断浏览器为微信浏览器
    解决表单(搜索框)回车的时候直接提交了表单不运行js的问题
    传智播客JavaWeb day11--事务的概念、事务的ACID、数据库锁机制、
    传智播客JavaWeb day10-jdbc操作mysql、连接数据库六大步骤
    页面上常用的一些小功能--QQ、回到顶部
    手机端禁止网页页面放大代码
    Resharp注册码
    NueGet设置package Source
  • 原文地址:https://www.cnblogs.com/gongxijun/p/7449561.html
Copyright © 2011-2022 走看看