zoukankan      html  css  js  c++  java
  • [转]tensorflow中的gather

    原文链接
    tensorflow中取下标的函数包括:tf.gather , tf.gather_nd 和 tf.batch_gather。

    1.tf.gather(params,indices,validate_indices=None,name=None,axis=0)

    indices必须是一维张量
    主要参数:

    • params:被索引的张量
    • indices:一维索引张量
    • name:返回张量名称

    返回值:通过indices获取params下标的张量。
    例子:

    import tensorflow as tf
    tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
    tensor_b = tf.Variable([1,2,0],dtype=tf.int32)
    tensor_c = tf.Variable([0,0],dtype=tf.int32)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(tf.gather(tensor_a,tensor_b)))
        print(sess.run(tf.gather(tensor_a,tensor_c)))
    

    上个例子tf.gather(tensor_a,tensor_b) 的值为[[4,5,6],[7,8,9],[1,2,3]],tf.gather(tensor_a,tensor_b) 的值为[[1,2,3],[1,2,3]]

    对于tensor_a,其第1个元素为[4,5,6],第2个元素为[7,8,9],第0个元素为[1,2,3],所以以[1,2,0]为索引的返回值是[[4,5,6],[7,8,9],[1,2,3]],同样的,以[0,0]为索引的值为[[1,2,3],[1,2,3]]。

    https://www.tensorflow.org/api_docs/python/tf/gather

    2.tf.gather_nd(params,indices,name=None)

    功能和参数与tf.gather类似,不同之处在于tf.gather_nd支持多维度索引,即indices可以使多维张量。
    例子:

    import tensorflow as tf
    tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
    tensor_b = tf.Variable([[1,0],[1,1],[1,2]],dtype=tf.int32)
    tensor_c = tf.Variable([[0,2],[2,0]],dtype=tf.int32)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(tf.gather_nd(tensor_a,tensor_b)))
        print(sess.run(tf.gather_nd(tensor_a,tensor_c)))
    tf.gather_nd(tensor_a,tensor_b)值为[4,5,6],tf.gather_nd(tensor_a,tensor_c)的值为[3,7].
    

    对于tensor_a,下标[1,0]的元素为4,下标为[1,1]的元素为5,下标为[1,2]的元素为6,索引[1,0],[1,1],[1,2]]的返回值为[4,5,6],同样的,索引[[0,2],[2,0]]的返回值为[3,7].

    https://www.tensorflow.org/api_docs/python/tf/gather_nd

    3.tf.batch_gather(params,indices,name=None)

    支持对张量的批量索引,各参数意义见(1)中描述。注意因为是批处理,所以indices要有和params相同的第0个维度。

    例子:

    import tensorflow as tf
    tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
    tensor_b = tf.Variable([[0],[1],[2]],dtype=tf.int32)
    tensor_c = tf.Variable([[0],[0],[0]],dtype=tf.int32)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(tf.batch_gather(tensor_a,tensor_b)))
        print(sess.run(tf.batch_gather(tensor_a,tensor_c)))
    tf.gather_nd(tensor_a,tensor_b)值为[1,5,9],tf.gather_nd(tensor_a,tensor_c)的值为[1,4,7].
    

    tensor_a的三个元素[1,2,3],[4,5,6],[7,8,9]分别对应索引元素的第一,第二和第三个值。[1,2,3]的第0个元素为1,[4,5,6]的第1个元素为5,[7,8,9]的第2个元素为9,所以索引[[0],[1],[2]]的返回值为[1,5,9],同样地,索引[[0],[0],[0]]的返回值为[1,4,7].

    https://www.tensorflow.org/api_docs/python/tf/batch_gather

    在深度学习的模型训练中,有时候需要对一个batch的数据进行类似于tf.gather_nd的操作,但tensorflow中并没有tf.batch_gather_nd之类的操作,此时需要tf.map_fn和tf.gather_nd结合来实现上述操作。

  • 相关阅读:
    边框颜色为 tintColor 的 UIButton
    iPhone: 在 iPhone app 里使用 UIPopoverController
    CCSprite: fade 效果切换图片
    iOS7以上: 实现如“日历”的 NavigationBar
    jqgrid自适应宽度
    NFine框架学习
    sql server 查看数据库文件的大小
    eclipse没有server选项
    MySql的数据导入到Sql Server数据库中
    eclispe javaw.exe in your current path的解决方法
  • 原文地址:https://www.cnblogs.com/weiyinfu/p/10607301.html
Copyright © 2011-2022 走看看