zoukankan      html  css  js  c++  java
  • tensorflow 的tf.where详解

    最近在用到数据筛选,观看代码中有tf.where()的用法,不是很常用,也不是很好理解。在这里记录一下

    1 tf.where(
    2     condition,
    3     x=None,
    4     y=None,
    5     name=None
    6 )

    Return the elements, either from x or y, depending on the condition.

    理解:where嘛,就是要根据条件找到你要的东西。

    condition:条件,是一个boolean

    x:数据

    y:同x维度的数据。

    返回,返回符合条件的数据。当条件为真,取x对应的数据;当条件为假,取y对应的数据

    举例子。

     1 def test_where():
     2     # 定义一个tensor,表示condition,内部数据随机产生
     3     condition = tf.convert_to_tensor(np.random.random([5]), dtype=tf.float32)
     4 
     5     # 定义两个tensor,表示原数据
     6     a = tf.ones(shape=[5, 3], name='a')
     7 
     8     b = tf.zeros(shape=[5, 3], name='b')
     9 
    10     # 选择大于0.5的数值的坐标,并根据condition信息在a和b中选取数据
    11     result = tf.where(condition > 0.5, a, b)
    12 
    13     with tf.Session() as sess:
    14         print("condition:
    ", sess.run([condition, result]))

    结果:

  • 相关阅读:
    如何向线程传递参数
    IntelliJ IDEA 13 Keygen
    单链表的基本操作
    顺序表静态查找
    有向图的十字链表表存储表示
    BF-KMP 算法
    图的邻接表存储表示(C)
    二叉树的基本操作(C)
    VC远控(三)磁盘显示
    Android 数独游戏 记录
  • 原文地址:https://www.cnblogs.com/demo-deng/p/10209605.html
Copyright © 2011-2022 走看看