zoukankan      html  css  js  c++  java
  • one-hot 编码


    def onehot(labels):
      '''one-hot 编码'''
      #数据有几行输出
      n_sample = len(labels)
      #数据分为几类。因为编码从0开始所以要加1
      n_class = max(labels) + 1
      #建立一个batch所需要的数组,全部赋0.
      onehot_labels = np.zeros((n_sample, n_class))
      #对每一行的,对应分类赋1
      onehot_labels[np.arange(n_sample), labels] = 1
      return onehot_labels

    运行结果:

    label=np.array([0,1,2])

    onehot(label)
    Out[8]:
    array([[ 1., 0., 0.],
    [ 0., 1., 0.],
    [ 0., 0., 1.]])

    label=np.array([0,4,7,1,1,1,4,1])

    onehot(label)
    Out[10]:
    array([[ 1., 0., 0., 0., 0., 0., 0., 0.],
    [ 0., 0., 0., 0., 1., 0., 0., 0.],
    [ 0., 0., 0., 0., 0., 0., 0., 1.],
    [ 0., 1., 0., 0., 0., 0., 0., 0.],
    [ 0., 1., 0., 0., 0., 0., 0., 0.],
    [ 0., 1., 0., 0., 0., 0., 0., 0.],
    [ 0., 0., 0., 0., 1., 0., 0., 0.],
    [ 0., 1., 0., 0., 0., 0., 0., 0.]])

    总结:本次标签只有一类,如第一个标签为一类,有两种情况。第二个为标签一类,有七种情况。如果标签为两类,比如{男生,女生}、{一年级、二年级、三年级},那编码的长度为5.

    onehot标签则是顾名思义,一个长度为n的数组,只有一个元素是1.0,其他元素是0.0。

    想想为什么要这样编码,知乎大佬的的一个解释感觉极有道理。

    使用onehot的直接原因是现在多分类cnn网络的输出通常是softmax层,而它的输出是一个概率分布,从而要求输入的标签也以概率分布的形式出现,进而算交叉熵之类。
    onehot其实就是给出了,真是的样本真实概率分布,其中一个样本数据概率为1,其他全为0.。计算损失交叉熵时,直接用1*log(1/概率),就直接算出了交叉熵,作为损失。



  • 相关阅读:
    Cat- Linux必学的60个命令
    Cmp- Linux必学的60个命令
    Diff- Linux必学的60个命令
    ls- Linux必学的60个命令
    mv- Linux必学的60个命令
    Find- Linux必学的60个命令
    libvirt
    PHP 设计模式 笔记与总结(2)开发 PSR-0 的基础框架
    Java实现 LeetCode 147 对链表进行插入排序
    Java实现 LeetCode 146 LRU缓存机制
  • 原文地址:https://www.cnblogs.com/smartwhite/p/8950600.html
Copyright © 2011-2022 走看看