zoukankan      html  css  js  c++  java
  • keras multi-label classification 多标签分类

    问题:一个数据又多个标签,一个样本数据多个类别中的某几类;比如一个病人的数据有多个疾病,一个文本有多种题材,所以标签就是: [1,0,0,0,1,0,1] 这种高维稀疏类型,如何计算分类准确率?

    分类问题:

    二分类

    多分类

    多标签

    Keras metrics (性能度量)

    介绍的比较好的一个博客:

    https://machinelearningmastery.com/custom-metrics-deep-learning-keras-python/

    还有一个介绍loss的博客:

    https://machinelearningmastery.com/how-to-choose-loss-functions-when-training-deep-learning-neural-networks/

    metrics:在训练的每个batch结束的时候计算训练集acc,如果提供验证集(一个epoch结束计算验证集acc),也同时计算验证集的性能度量,分为回归任务和分类任务,有不同的acc计算办法;metrics 里面可以放 loss (回归问题)或者acc(分类问题);

    A metric is a function that is used to judge the performance of your model.

    A metric function is similar to a loss function, except that the results from evaluating a metric are not used when training the model. You may use any of the loss functions as a metric function.

    metrics其实和loss类似,只是不用来指导网络的训练;一般根据具体问题具体要求采用不同的 metric 函数,衡量性能;

    分类问题的不同acc计算方法:

    • Binary Accuracy: binary_accuracy, acc
    • Categorical Accuracy: categorical_accuracy, acc
    • Sparse Categorical Accuracy: sparse_categorical_accuracy
    • Top k Categorical Accuracy: top_k_categorical_accuracy (requires you specify a k parameter)
    • Sparse Top k Categorical Accuracy: sparse_top_k_categorical_accuracy (requires you specify a k parameter)

    keras metrics 默认的accuracy:

    metrics["accuracy"] :   == categorical_accuracy; 最快的验证方法,训练一个简单网络,同时输出默认accuracy,categorical_accuracy,,binaray_accuracy, 对比就可以知道;

    或者看keras源码,找到metrics默认设置:

      

    多标签分类问题:

    [1,0,0,1,0] , [1,0,0,0,0] 分别是 y_pred, y_true:

    如果使用 binary_accuracy : acc = 0.8;

    if the prediction would be [0, 0, 0, 0, 0, 1]. And if the actual labels were [0, 0, 0, 0, 0, 0], the accuracy would be 5/6.;

    训练过程常见坑:

    1.自定义loss:

    自定义loss写成函数的时候,keras compile() 里面,要调用自定义的loss函数而不是只给函数名:

    model.compile(optimizer="adam", loss=self_defined_loss(), metrics=["accuracy"])
     
    2. 关于top5 , top1 ACC:--(针对多分类不是多标签问题)
    一个图片可能是 [猫,狗,大象,老鼠,小皮球,房子]里面的一种;我们对每个图片输出一个概率分布 [0.3,0.2,0.1,0.1,0.3] , 如果:
    top1: 概率最高的预测类别是否和真实标签一致;
    top5:概率最高的5个预测类别是否包含了真实标签;
     
  • 相关阅读:
    POJ 3140 Contestants Division (树dp)
    POJ 3107 Godfather (树重心)
    POJ 1655 Balancing Act (树的重心)
    HDU 3534 Tree (经典树形dp)
    HDU 1561 The more, The Better (树形dp)
    HDU 1011 Starship Troopers (树dp)
    Light oj 1085
    Light oj 1013
    Light oj 1134
    FZU 2224 An exciting GCD problem(GCD种类预处理+树状数组维护)同hdu5869
  • 原文地址:https://www.cnblogs.com/robin2ML/p/11560786.html
Copyright © 2011-2022 走看看