zoukankan      html  css  js  c++  java
  • 【火炉炼AI】机器学习016-如何知道SVM模型输出类别的置信度

    【火炉炼AI】机器学习016-如何知道SVM模型输出类别的置信度

    (本文所使用的Python库和版本号: Python 3.5, Numpy 1.14, scikit-learn 0.19, matplotlib 2.2 )

    一般的,对于未知样本,我们通过模型预测出来属于某种类别,往往会给出是这种类别的概率。

    比如通过AI模型识别某一种图片是“狗”的概率是95.8%,是”猫“的概率是4.2%,那用SVM能不能得到类似的属于某种类别的概率值了?


    1. 准备数据集

    本部分代码和上一篇文章(【火炉炼AI】机器学习014-用SVM构建非线性分类模型 )几乎一样,故此不再赘述。


    2. 计算新样本的置信度

    此处,我们自己构建了非线性SVM分类模型,并使用该模型对新样本数据进行类别的分类。如下代码

    # 计算某个新样本的置信度
    new_samples=np.array([[2,1.5],
                          [8,9],
                          [4.8,5.2],
                          [4,4],
                          [2.5,7],
                          [7.6,2],
                          [5.4,5.9]])
    classifier3=SVC(kernel='rbf',probability=True) # 比上面的分类器增加了 probability=True参数
    classifier3.fit(train_X,train_y)
    
    # 使用训练好的SVM分类器classifier3对新样本进行预测,并给出置信度
    for sample in new_samples:
        print('sample: {}, probs: {}'.format(sample,classifier3.predict_proba([sample])[0]))
    

    -------------------------------------输---------出--------------------------------

    sample: [2. 1.5], probs: [0.08066588 0.91933412]
    sample: [8. 9.], probs: [0.08311977 0.91688023]
    sample: [4.8 5.2], probs: [0.14367183 0.85632817]
    sample: [4. 4.], probs: [0.06178594 0.93821406]
    sample: [2.5 7. ], probs: [0.21050117 0.78949883]
    sample: [7.6 2. ], probs: [0.07548128 0.92451872]
    sample: [5.4 5.9], probs: [0.45817727 0.54182273]

    --------------------------------------------完-------------------------------------

    将该新样本数据点绘制到2D分布图中,可以得到如下结果。

    使用rbf类型的SVM分类器得到的分类效果

    ########################小**********结###############################

    1. 从中可以看出,如果要输出每一个类别的不同概率,需要设置参数probability=True,同时,需要用classifier.predict_proba()函数来获取类别概率值。

    2. 模型输出的样本类别的概率,就是该样本属于这个类别的置信度。

    #################################################################


    注:本部分代码已经全部上传到(我的github)上,欢迎下载。

    参考资料:

    1, Python机器学习经典实例,Prateek Joshi著,陶俊杰,陈小莉译

  • 相关阅读:
    【1】Chrome
    Vue
    GitHub版本控制工具入门(一)
    Vue.js 组件笔记
    最全的javascriptt选择题整理
    网站如何实现 在qq中发自己链接时,便自动获取链接标题、图片和部分内容
    js 唤起APP
    密码加密MD5,Bash64
    HTTP和HTTPS的区别及HTTPS加密算法
    计算机网络七层的理解
  • 原文地址:https://www.cnblogs.com/RayDean/p/9765699.html
Copyright © 2011-2022 走看看