zoukankan      html  css  js  c++  java
  • one-hot 编码


    def onehot(labels):
      '''one-hot 编码'''
      #数据有几行输出
      n_sample = len(labels)
      #数据分为几类。因为编码从0开始所以要加1
      n_class = max(labels) + 1
      #建立一个batch所需要的数组,全部赋0.
      onehot_labels = np.zeros((n_sample, n_class))
      #对每一行的,对应分类赋1
      onehot_labels[np.arange(n_sample), labels] = 1
      return onehot_labels

    运行结果:

    label=np.array([0,1,2])

    onehot(label)
    Out[8]:
    array([[ 1., 0., 0.],
    [ 0., 1., 0.],
    [ 0., 0., 1.]])

    label=np.array([0,4,7,1,1,1,4,1])

    onehot(label)
    Out[10]:
    array([[ 1., 0., 0., 0., 0., 0., 0., 0.],
    [ 0., 0., 0., 0., 1., 0., 0., 0.],
    [ 0., 0., 0., 0., 0., 0., 0., 1.],
    [ 0., 1., 0., 0., 0., 0., 0., 0.],
    [ 0., 1., 0., 0., 0., 0., 0., 0.],
    [ 0., 1., 0., 0., 0., 0., 0., 0.],
    [ 0., 0., 0., 0., 1., 0., 0., 0.],
    [ 0., 1., 0., 0., 0., 0., 0., 0.]])

    总结:本次标签只有一类,如第一个标签为一类,有两种情况。第二个为标签一类,有七种情况。如果标签为两类,比如{男生,女生}、{一年级、二年级、三年级},那编码的长度为5.

    onehot标签则是顾名思义,一个长度为n的数组,只有一个元素是1.0,其他元素是0.0。

    想想为什么要这样编码,知乎大佬的的一个解释感觉极有道理。

    使用onehot的直接原因是现在多分类cnn网络的输出通常是softmax层,而它的输出是一个概率分布,从而要求输入的标签也以概率分布的形式出现,进而算交叉熵之类。
    onehot其实就是给出了,真是的样本真实概率分布,其中一个样本数据概率为1,其他全为0.。计算损失交叉熵时,直接用1*log(1/概率),就直接算出了交叉熵,作为损失。



  • 相关阅读:
    前端性能优化——写给网页设计师和前端工程师看的
    V8引擎——详解
    Perl_实用报表提取语言
    qs.stringify和JSON.stringify()
    js之history
    js考察this,作用域链和闭包
    css table之合并单元格
    js手机浏览器浏览WebApp弹出的键盘遮盖住文本框的解决办法
    windows 杀进程
    axios库的使用
  • 原文地址:https://www.cnblogs.com/smartwhite/p/8950600.html
Copyright © 2011-2022 走看看