zoukankan      html  css  js  c++  java
  • tf.nn.embedding_lookup

    tf.nn.embedding_lookup记录

    96 

    tf.nn.embedding_lookup函数的用法主要是选取一个张量里面索引对应的元素。tf.nn.embedding_lookup(tensor, id):tensor就是输入张量,id就是张量对应的索引,其他的参数不介绍。

    例如:

    import tensorflow as tf;
    import numpy as np;

    c = np.random.random([10,1])
    b = tf.nn.embedding_lookup(c, [1, 3])

    with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    print sess.run(b)
    print c
    输出:
    [[ 0.77505197]
     [ 0.20635818]]
    [[ 0.23976515]
     [ 0.77505197]
     [ 0.08798201]
     [ 0.20635818]
     [ 0.37183035]
     [ 0.24753178]
     [ 0.17718483]
     [ 0.38533808]
     [ 0.93345168]
     [ 0.02634772]]

    分析:输出为张量的第一和第三个元素。
    ---------------------
    作者:UESTC_C2_403
    来源:CSDN
    原文:https://blog.csdn.net/uestc_c2_403/article/details/72779417
    版权声明:本文为博主原创文章,转载请附上博文链接!

    slade_sal 
    2018.06.11 10:33* 字数 375 阅读 5099评论 0
     

    我觉得这张图就够了,实际上tf.nn.embedding_lookup的作用就是找到要寻找的embedding data中的对应的行下的vector。

    tf.nn.embedding_lookup(params, ids, partition_strategy='mod', name=None, validate_indices=True, max_norm=None)
    

    官方文档位置,其中,params是我们给出的,可以通过:
    1.tf.get_variable("item_emb_w", [self.item_count, self.embedding_size])等方式生产服从[0,1]的均匀分布或者标准分布
    2.tf.convert_to_tensor转化我们现有的array
    然后,ids是我们要找的params中对应位置。

    举个例子:

    import numpy as np
    import tensorflow as tf
    data = np.array([[[2],[1]],[[3],[4]],[[6],[7]]])
    data = tf.convert_to_tensor(data)
    lk = [[0,1],[1,0],[0,0]]
    lookup_data = tf.nn.embedding_lookup(data,lk)
    init = tf.global_variables_initializer()
    

    先让我们看下不同数据对应的维度:

    In [76]: data.shape
    Out[76]: (3, 2, 1)
    In [77]: np.array(lk).shape
    Out[77]: (3, 2)
    In [78]: lookup_data
    Out[78]: <tf.Tensor 'embedding_lookup_8:0' shape=(3, 2, 2, 1) dtype=int64>
    

    这个是怎么做到的呢?关键的部分来了,看下图:

     

    lk中的值,在要寻找的embedding数据中下找对应的index下的vector进行拼接。永远是look(lk)部分的维度+embedding(data)部分的除了第一维后的维度拼接。很明显,我们也可以得到,lk里面值是必须要小于等于embedding(data)的最大维度减一的。

    以上的结果就是:

    In [79]: data
    Out[79]:
    array([[[2],
            [1]],
    
           [[3],
            [4]],
    
           [[6],
            [7]]])
    
    In [80]: lk
    Out[80]: [[0, 1], [1, 0], [0, 0]]
    
    # lk[0]也就是[0,1]对应着下面sess.run(lookup_data)的结果恰好是把data中的[[2],[1]],[[3],[4]]
    
    In [81]: sess.run(lookup_data)
    Out[81]:
    array([[[[2],
             [1]],
    
            [[3],
             [4]]],
    
    
           [[[3],
             [4]],
    
            [[2],
             [1]]],
    
    
           [[[2],
             [1]],
    
            [[2],
             [1]]]])
    

    最后,partition_strategy是用于当len(params) > 1,params的元素分割不能整分的话,则前(max_id + 1) % len(params)多分一个id.
    当partition_strategy = 'mod'的时候,13个ids划分为5个分区:[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]],也就是是按照数据列进行映射,然后再进行look_up操作。
    当partition_strategy = 'div'的时候,13个ids划分为5个分区:[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]],也就是是按照数据先后进行排序标序,然后再进行look_up操作。

    萍水相逢逢萍水,浮萍之水水浮萍!
  • 相关阅读:
    Excel基础—文件菜单之创建保存
    Excel技巧—名称框的妙用
    Excel基础—文件菜单之设置信息
    Excel基础—文件菜单之打印共享账户
    Excel基础—文件菜单之设置选项
    Excel基础—工作界面概述
    linux环境下pathinfo 工作失败的改进函数
    javascript为网页元素绑定click事件
    将纯真ip数据库解析并导入mysql数据库中
    pgsql导入和导出数据
  • 原文地址:https://www.cnblogs.com/AIBigTruth/p/10349174.html
Copyright © 2011-2022 走看看