zoukankan      html  css  js  c++  java
  • keras multi-label classification 多标签分类

    问题:一个数据又多个标签,一个样本数据多个类别中的某几类;比如一个病人的数据有多个疾病,一个文本有多种题材,所以标签就是: [1,0,0,0,1,0,1] 这种高维稀疏类型,如何计算分类准确率?

    分类问题:

    二分类

    多分类

    多标签

    Keras metrics (性能度量)

    介绍的比较好的一个博客:

    https://machinelearningmastery.com/custom-metrics-deep-learning-keras-python/

    还有一个介绍loss的博客:

    https://machinelearningmastery.com/how-to-choose-loss-functions-when-training-deep-learning-neural-networks/

    metrics:在训练的每个batch结束的时候计算训练集acc,如果提供验证集(一个epoch结束计算验证集acc),也同时计算验证集的性能度量,分为回归任务和分类任务,有不同的acc计算办法;metrics 里面可以放 loss (回归问题)或者acc(分类问题);

    A metric is a function that is used to judge the performance of your model.

    A metric function is similar to a loss function, except that the results from evaluating a metric are not used when training the model. You may use any of the loss functions as a metric function.

    metrics其实和loss类似,只是不用来指导网络的训练;一般根据具体问题具体要求采用不同的 metric 函数,衡量性能;

    分类问题的不同acc计算方法:

    • Binary Accuracy: binary_accuracy, acc
    • Categorical Accuracy: categorical_accuracy, acc
    • Sparse Categorical Accuracy: sparse_categorical_accuracy
    • Top k Categorical Accuracy: top_k_categorical_accuracy (requires you specify a k parameter)
    • Sparse Top k Categorical Accuracy: sparse_top_k_categorical_accuracy (requires you specify a k parameter)

    keras metrics 默认的accuracy:

    metrics["accuracy"] :   == categorical_accuracy; 最快的验证方法,训练一个简单网络,同时输出默认accuracy,categorical_accuracy,,binaray_accuracy, 对比就可以知道;

    或者看keras源码,找到metrics默认设置:

      

    多标签分类问题:

    [1,0,0,1,0] , [1,0,0,0,0] 分别是 y_pred, y_true:

    如果使用 binary_accuracy : acc = 0.8;

    if the prediction would be [0, 0, 0, 0, 0, 1]. And if the actual labels were [0, 0, 0, 0, 0, 0], the accuracy would be 5/6.;

    训练过程常见坑:

    1.自定义loss:

    自定义loss写成函数的时候,keras compile() 里面,要调用自定义的loss函数而不是只给函数名:

    model.compile(optimizer="adam", loss=self_defined_loss(), metrics=["accuracy"])
     
    2. 关于top5 , top1 ACC:--(针对多分类不是多标签问题)
    一个图片可能是 [猫,狗,大象,老鼠,小皮球,房子]里面的一种;我们对每个图片输出一个概率分布 [0.3,0.2,0.1,0.1,0.3] , 如果:
    top1: 概率最高的预测类别是否和真实标签一致;
    top5:概率最高的5个预测类别是否包含了真实标签;
     
  • 相关阅读:
    某公司面试的SQL题目
    列存储索引
    JList动态添加元素
    Java中堆、栈、常量池等概念解析
    JButton大小设置问题?
    JAVA中定时器的使用
    线性表和链表的区别
    JTable表头显示问题以及如何让某行选中
    JPanel如何设置背景图片
    关于Scanner调用nextInt()异常try后不能二次输入问题
  • 原文地址:https://www.cnblogs.com/robin2ML/p/11560786.html
Copyright © 2011-2022 走看看