zoukankan      html  css  js  c++  java
  • Multi label 多标签分类问题(Pytorch,TensorFlow,Caffe)

    适用场景:一个输入对应多个label,或输入类别间不互斥

    调用函数:

    1. Pytorch使用torch.nn.BCEloss

    2. Tensorflow使用tf.losses.sigmoid_cross_entropy

    3. Caffe使用SigmoidCrossEntropyLoss

    在output和target之间构建binary cross entropy,其中i为每一个类。

     

    以pytorch为例:Caffe,TensorFlow版本类比,输入均为相同形式的向量

    m = nn.Sigmoid()
    loss = nn.BCELoss()
    input = autograd.Variable(torch.randn(3), requires_grad=True)
    target = autograd.Variable(torch.FloatTensor(3).random_(2))
    output = loss(m(input), target)
    output.backward()

    注意target的形式,要写成01编码形式,eg:如果同时为第一类和第三类则,[1, 0, 1]

    主要是结合sigmoid来使用,经过classifier分类过后的输出为(batch_size,num_class)为每个数据的标签, 标签不是one-hot的主要体现在sigmoid输出之后,仍然为(batch_size,num_class),对于一个实例,它的各个label的分数加起来不一定等于1,bceloss在每个类维度上求cross entropy loss然后加和求平均得到,这里就体现了多标签的思想。

    [CVPR2015] Is object localization for free? – Weakly-supervised learning with convolutional neural networks这篇论文里设计了针对多标签问题的loss,传统的类别分类不适用,作者把这个任务视为多个二分类问题,loss function和分类的分数如下:

     

  • 相关阅读:
    2018.5.17 memcached
    2018.5.11 B树总结
    2018.5.8 排序总结
    2018.5.8 python操纵sqlite数据库
    2018.5.4 Unix的五种IO模型
    2018.5.3 maven
    2018.5.3 docker
    Mybatis学习笔记,挺全的!
    这么强大的Mybatis插件机制原来就是这?
    Swagger API Spec + Swagger Codegen + YAPI管理接口文档
  • 原文地址:https://www.cnblogs.com/demian/p/9674204.html
Copyright © 2011-2022 走看看