zoukankan      html  css  js  c++  java
  • tensorflow对多维tensor按照指定索引重排序

    背景是这样的,

    比如我有一个张量data,shape是(batch_size,100,128)

    我还有一个张量inc,shape是(batch_size,100)

    我现在想根据这个张量地索引来对data重排序。

    为什么会有这样地需求呢,是因为比如data是数据,100代表数据步长,128代表数据内units数目(维度),inc代表一个分数,这个分数表明了这100个步长当中每一步的重要性。现在我想要对data重排序一下,取top10,变成(batch_size,10,128),这样有利于后面的Attention。

    操作例子见代码:

    最主要的思想就是你有一个N维向量,那么就要指定一个N-1维的索引来对其重排序。例子中我们是一个(batch_size,100,128)的数据,

    那么如果:

    data是(batch_size,A,B,C,100,128)

    inc是(batch_size,A,B,C,100,128)呢?

    我的想法是先data reshape成(batch_size*A*B*C,100,128)

    inc reshape成(batch_size*A*B*C,100)

    后面的操作就一样了,先unstack,分别用gather取出相应切片(其实这里就已经做了个排序)

    然后再stack回去

    import tensorflow as tf
    import numpy as np
    
    data = tf.placeholder(tf.int64, [None, 5, 2])
    
    choose = tf.placeholder(tf.int64,[None,5])
    sortarg = tf.argsort(choose, direction="DESCENDING")
    split_data = tf.unstack(data, num=3, axis=0)
    split_choose = tf.unstack(sortarg, num=3, axis=0)
    trans_data_list = list()
    for i in range(3):
        trans_data_list.append(tf.gather(split_data[i], sortarg[i]))
    trans_data = tf.stack(trans_data_list, axis=0)
    
    
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        feed_dict = {
            choose:[[5,4,3,0,1],[2,3,0,4,2],[2,3,5,4,2]],
            data:[[[1,2],[3,4],[5,6],[7,8],[9,10]], [[11,12],[13,14],[15,16],[17,18],[19,20]], [[21,22],[23,24],[25,26],[27,28],[29,30]]]
        }
        print(sess.run(sortarg,feed_dict=feed_dict))
        print("-----------------------------------------------------")
        # print(sess.run(data_trans,feed_dict = feed_dict))
        print(sess.run(data,feed_dict=feed_dict))
        print("-----------------------------------------------------")
        print(sess.run(trans_data, feed_dict=feed_dict))
    

      

  • 相关阅读:
    Task 5.1 电梯调度程序需求调研报告
    Task 4 求数组的连续子数组的最大和(团队合作)
    class 3 求数组中的最大值(单元测试)
    《你的灯亮着吗》读书笔记3
    优惠购书
    校友聊NABCD
    环状二维数组(改进版)
    环状二维数组
    环状一维数组
    二维数组最大值
  • 原文地址:https://www.cnblogs.com/zhouxiaosong/p/12203119.html
Copyright © 2011-2022 走看看