zoukankan      html  css  js  c++  java
  • 混淆矩阵(Confusion matrix)的原理及使用(scikit-learn 和 tensorflow)

    原理

      在机器学习中, 混淆矩阵是一个误差矩阵, 常用来可视化地评估监督学习算法的性能. 混淆矩阵大小为 (n_classes, n_classes) 的方阵, 其中 n_classes 表示类的数量. 这个矩阵的每一行表示真实类中的实例, 而每一列表示预测类中的实例 (Tensorflow 和 scikit-learn 采用的实现方式). 也可以是, 每一行表示预测类中的实例, 而每一列表示真实类中的实例 (Confusion matrix From Wikipedia 中的定义). 通过混淆矩阵, 可以很容易看出系统是否会弄混两个类, 这也是混淆矩阵名字的由来.

      混淆矩阵是一种特殊类型的列联表(contingency table)或交叉制表(cross tabulation or crosstab). 其有两维 (真实值 "actual" 和 预测值 "predicted" ), 这两维都具有相同的类("classes")的集合. 在列联表中, 每个维度和类的组合是一个变量. 列联表以表的形式, 可视化地表示多个变量的频率分布. 

    使用混淆矩阵( scikit-learn 和 Tensorflow)

      下面先介绍在 scikit-learn 和 tensorflow 中计算混淆矩阵的 API (Application Programming Interface) 接口函数, 然后在一个示例中, 使用这两个 API 函数.

    scikit-learn 混淆矩阵函数 sklearn.metrics.confusion_matrix API 接口

    skearn.metrics.confusion_matrix(
        y_true,   # array, Gound true (correct) target values
        y_pred,  # array, Estimated targets as returned by a classifier
        labels=None,  # array, List of labels to index the matrix.
        sample_weight=None  # array-like of shape = [n_samples], Optional sample weights
    )

    在 scikit-learn 中, 计算混淆矩阵用来评估分类的准确度.

      按照定义, 混淆矩阵 C 中的元素 Ci,j 等于真实值为组 i , 而预测为组 j 的观测数(the number of observations). 所以对于二分类任务, 预测结果中, 正确的负例数(true negatives, TN)为 C0,0; 错误的负例数(false negatives, FN)为 C1,0; 真实的正例数为 C1,1; 错误的正例数为 C0,1.

      如果 labels 为 None, scikit-learn 会把在出现在 y_true 或 y_pred 中的所有值添加到标记列表 labels 中, 并排好序. 

    Tensorflow 混淆矩阵函数 tf.confusion_matrix API 接口

    tf.confusion_matrix(
        labels,   # 1-D Tensor of real labels for the classification task
        predictions,   # 1-D Tensor of predictions for a givenclassification
        num_classes=None,  #  The possible number of labels the classification task can have
        dtype=tf.int32,   # Data type of the confusion matrix 
        name=None,    # Scope name
        weights=None,    # An optional Tensor whose shape matches predictions
    )

      Tensorflow tf.confusion_matrix 中的 num_classes 参数的含义, 与 scikit-learn sklearn.metrics.confusion_matrix 中的 labels 参数相近, 是与标记有关的参数, 表示类的总个数, 但没有列出具体的标记值. 在 Tensorflow 中一般是以整数作为标记, 如果标记为字符串等非整数类型, 则需先转为整数表示. 如果 num_classes 参数为 None, 则把 labels 和 predictions 中的最大值 + 1, 作为 num_classes 参数值.

      tf.confusion_matrix 的 weights 参数和 sklearn.metrics.confusion_matrix 的 sample_weight 参数的含义相同, 都是对预测值进行加权, 在此基础上, 计算混淆矩阵单元的值.

    使用示例

    #!/usr/bin/env python
    # -*- coding: utf8 -*-
    """
    Author: klchang
    Description:
      A simple example for tf.confusion_matrix and sklearn.metrics.confusion_matrix.
    Date: 2018.9.8
    """
    from __future__ import print_function import tensorflow as tf import sklearn.metrics y_true = [1, 2, 4] y_pred = [2, 2, 4] # Build graph with tf.confusion_matrix operation sess = tf.InteractiveSession() op = tf.confusion_matrix(y_true, y_pred) op2 = tf.confusion_matrix(y_true, y_pred, num_classes=6, dtype=tf.float32, weights=tf.constant([0.3, 0.4, 0.3])) # Execute the graph print ("confusion matrix in tensorflow: ") print ("1. default: ", op.eval()) print ("2. customed: ", sess.run(op2))
    sess.close()
    # Use sklearn.metrics.confusion_matrix function print (" confusion matrix in scikit-learn: ") print ("1. default: ", sklearn.metrics.confusion_matrix(y_true, y_pred)) print ("2. customed: ", sklearn.metrics.confusion_matrix(y_true, y_pred, labels=range(6), sample_weight=[0.3, 0.4, 0.3]))

    参考资料

    1. Confusion matrix. In Wikipedia, The Free Encyclopedia. https://en.wikipedia.org/wiki/Confusion_matrix

    2. Contingency table. In Wikipedia, The Free Encyclopedia. https://en.wikipedia.org/wiki/Contingency_table

    3. Tensorflow API - tf.confusion_matrix. https://www.tensorflow.com/api_docs/python/tf/confusion_matrix

    4.  scikit-learn API - sklearn.metrics.confusion_matrix. http://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html

  • 相关阅读:
    Blank page instead of the SharePoint Central Administration site
    BizTalk 2010 BAM Configure
    Use ODBA with Visio 2007
    Handling SOAP Exceptions in BizTalk Orchestrations
    BizTalk与WebMethods之间的EDI交换
    Append messages in BizTalk
    FTP protocol commands
    Using Dynamic Maps in BizTalk(From CodeProject)
    Synchronous To Asynchronous Flows Without An Orchestration的简单实现
    WSE3 and "Action for ultimate recipient is required but not present in the message."
  • 原文地址:https://www.cnblogs.com/klchang/p/9608412.html
Copyright © 2011-2022 走看看