zoukankan      html  css  js  c++  java
  • 利用鸢尾花数据集绘制P-R曲线

     1 #利用鸢尾花数据集绘制P-R曲线
     2 print(__doc__)      #打印注释
     3  
     4 import matplotlib.pyplot as plt
     5 import numpy as np
     6 from sklearn import svm, datasets
     7 from sklearn.metrics import precision_recall_curve
     8 from sklearn.metrics import average_precision_score
     9 from sklearn.preprocessing import label_binarize
    10 from sklearn.multiclass import OneVsRestClassifier  #一对其余(每次将一个类作为正类,剩下的类作为负类)
    11  
    12 from sklearn.cross_validation import train_test_split  #适用于anaconda 3.6及以前版本
    13 #from sklearn.model_selection import train_test_split#适用于anaconda 3.7
    14  
    15 #以iris数据为例,画出P-R曲线
    16 iris = datasets.load_iris()
    17 X = iris.data    #150*4
    18 y = iris.target  #150*1
    19  
    20 # 标签二值化,将三个类转为001, 010, 100的格式.因为这是个多类分类问题,后面将要采用
    21 #OneVsRestClassifier策略转为二类分类问题
    22 y = label_binarize(y, classes=[0, 1, 2])    #将150*1转化成150*3
    23 n_classes = y.shape[1]                      #列的个数,等于3
    24 print (y)
    25  
    26 # 增加了800维的噪声特征
    27 random_state = np.random.RandomState(0)
    28 n_samples, n_features = X.shape
    29  
    30 X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]   #行不变,只增加了列,150*804
    31  
    32 # 训练集和测试集拆分,比例为0.5
    33 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5, random_state=random_state) #随机数,填0或不填,每次都会不一样
    34  
    35 # 一对其余,转换成两类,构建新的分类器
    36 classifier = OneVsRestClassifier(svm.SVC(kernel='linear', probability=True, random_state=random_state))
    37 #训练集送给fit函数进行拟合训练,训练完后将测试集的样本特征注入,得到测试集中每个样本预测的分数
    38 y_score = classifier.fit(X_train, y_train).decision_function(X_test)
    39  
    40 # Compute Precision-Recall and plot curve  
    41 #下面的下划线是返回的阈值。作为一个名称:此时“_”作为临时性的名称使用。
    42 #表示分配了一个特定的名称,但是并不会在后面再次用到该名称。
    43 precision = dict()
    44 recall = dict()
    45 average_precision = dict()
    46 for i in range(n_classes):
    47     #对于每一类,计算精确率和召回率的序列(:表示所有行,i表示第i列)
    48     precision[i], recall[i], _ = precision_recall_curve(y_test[:, i],  y_score[:, i]) 
    49     average_precision[i] = average_precision_score(y_test[:, i], y_score[:, i])#切片,第i个类的分类结果性能
    50  
    51 # Compute micro-average curve and area. ravel()将多维数组降为一维
    52 precision["micro"], recall["micro"], _ = precision_recall_curve(y_test.ravel(),  y_score.ravel())
    53 average_precision["micro"] = average_precision_score(y_test, y_score, average="micro") #This score corresponds to the area under the precision-recall curve.
    54  
    55 # Plot Precision-Recall curve for each class
    56 plt.clf()#clf 函数用于清除当前图像窗口
    57 plt.plot(recall["micro"], precision["micro"],
    58          label='micro-average Precision-recall curve (area = {0:0.2f})'.format(average_precision["micro"]))
    59 for i in range(n_classes):
    60     plt.plot(recall[i], precision[i],
    61              label='Precision-recall curve of class {0} (area = {1:0.2f})'.format(i, average_precision[i]))
    62  
    63 plt.xlim([0.0, 1.0])
    64 plt.ylim([0.0, 1.05]) #xlim、ylim:分别设置X、Y轴的显示范围。
    65 plt.xlabel('Recall', fontsize=16)
    66 plt.ylabel('Precision',fontsize=16)
    67 plt.title('Extension of Precision-Recall curve to multi-class',fontsize=16)
    68 plt.legend(loc="lower right")#legend 是用于设置图例的函数
    69 plt.show()

    运行结果如下:

  • 相关阅读:
    JS中使用正则表达式封装的一些常用的格式验证的方法-是否外部url、是否小写、邮箱格式、是否字符、是否数组
    Java中操作字符串的工具类-判空、截取、格式化、转换驼峰、转集合和list、是否包含
    Cocos2d-x 2.0 自适应多种分辨率
    应用自定义移动设备外观
    为移动设备应用程序创建外观
    【2020-11-28】人生十三信条
    【2020-11-27】事实证明,逃避是下等策略
    Python 之web动态服务器
    Python 之pygame飞机游戏
    PHP 之转换excel表格中的经纬度
  • 原文地址:https://www.cnblogs.com/cxq1126/p/13018923.html
Copyright © 2011-2022 走看看