zoukankan      html  css  js  c++  java
  • 『TensorFlow』one_hot化标签

    tf.one_hot(indices, depth):将目标序列转换成one_hot编码

    tf.one_hot
    (indices, depth, on_value=None, off_value=None, 
    axis=None, dtype=None, name=None)

    indices = [0, 2, -1, 1]
    depth = 3
    on_value = 5.0 
    off_value = 0.0 
    axis = -1 
    #Then output is [4 x 3]: 
    output = 
    [5.0 0.0 0.0] // one_hot(0) 
    [0.0 0.0 5.0] // one_hot(2) 
    [0.0 0.0 0.0] // one_hot(-1) 
    [0.0 5.0 0.0] // one_hot(1)

    with tf.Session() as sess:
      print(sess.run(tf.one_hot(np.array([np.array([0,1,2,3]),np.array([2,0,3,2])]),depth=4,axis=-1)))
    
    # [[[ 1.  0.  0.  0.]
    #    [ 0.  1.  0.  0.]
    #    [ 0.  0.  1.  0.]
    #    [ 0.  0.  0.  1.]]
    #   [[ 0.  0.  1.  0.]
    #    [ 1.  0.  0.  0.]
    #    [ 0.  0.  0.  1.]
    #    [ 0.  0.  1.  0.]]]
    
    
    oh = tf.one_hot(indices = [0, 2, -1, 1], depth = 3,  on_value = 5.0 , off_value = 0.0, axis = -1)
    sess = tf.Session()
    sess.run(oh)
    
    # array([[5., 0., 0.],
    #        [0., 0., 5.],
    #        [0., 0., 0.],
    #        [0., 5., 0.]], dtype=float32)
    

    另一种思路:稀疏张量构建法

    import numpy as np
    import tensorflow as tf
    
    NUMCLASS = 3
    batch_size = 5
    
    labels = tf.placeholder(dtype=tf.int32, shape=[batch_size, 1])
    index = tf.reshape(tf.range(0, batch_size,1), [batch_size, 1])
    one_hot = tf.sparse_to_dense(
                                 tf.concat(values=[index, labels], axis=1),
                                 [batch_size, NUMCLASS],
                                 1.0, 0.0
                                 )
    with tf.Session() as sess:
        lab = np.random.randint(0,3,[5,1])
        print(sess.run(one_hot, feed_dict={labels:lab}))
        print(sess.run(tf.one_hot(np.squeeze(lab),depth=3,axis=1)))
    

    注意两种方法输入数据维度的变化(稀疏法为了得到足够的索引需要升维),结果如下:

    [[ 1.  0.  0.]
     [ 1.  0.  0.]
     [ 0.  0.  1.]
     [ 1.  0.  0.]
     [ 0.  1.  0.]]
    [[ 1.  0.  0.]
     [ 1.  0.  0.]
     [ 0.  0.  1.]
     [ 1.  0.  0.]
     [ 0.  1.  0.]]
  • 相关阅读:
    MYSQL学习(二)
    Nginx学习总结(一)
    关于微服务架构的个人理解(一)
    深入理解Java虚拟机(二) : 垃圾回收
    深入理解Java虚拟机(一) 运行时数据区划分
    多线程系列之 线程安全
    多线程系列之 java多线程的个人理解(二)
    多线程系列之 Java多线程的个人理解(一)
    Java基础04—字符串
    Java基础03—流程控制
  • 原文地址:https://www.cnblogs.com/hellcat/p/8568253.html
Copyright © 2011-2022 走看看