zoukankan      html  css  js  c++  java
  • tf.nn.embedding_lookup函数的用法

    关于np.random.RandomState、np.random.rand、np.random.random、np.random_sample参考https://blog.csdn.net/lanchunhui/article/details/50405670

    tf.nn.embedding_lookup函数的用法主要是选取一个张量里面索引对应的元素。tf.nn.embedding_lookup(params, ids):params可以是张量也可以是数组等,id就是对应的索引,其他的参数不介绍。

    例如:

    ids只有一行:

    #c = np.random.random([10, 1])  # 随机生成一个10*1的数组
    #b = tf.nn.embedding_lookup(c, [1, 3])#查找数组中的序号为1和3的
    p=tf.Variable(tf.random_normal([10,1]))#生成10*1的张量
    b = tf.nn.embedding_lookup(p, [1, 3])#查找张量中的序号为1和3的
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(b))
        #print(c)
        print(sess.run(p))
        print(p)
        print(type(p))
    

      

    输出:

    [[0.15791859]
     [0.6468804 ]]
    [[-0.2737084 ]
     [ 0.15791859]
     [-0.01315552]
     [ 0.6468804 ]
     [-1.4090979 ]
     [ 2.1583703 ]
     [ 1.4137447 ]
     [ 0.20688428]
     [-0.32815856]
     [-1.0601649 ]]
    <tf.Variable 'Variable:0' shape=(10, 1) dtype=float32_ref>
    <class 'tensorflow.python.ops.variables.Variable'>
    

     分析:输出为张量的第一和第三个元素。

    如果ids是多行:

    import tensorflow as tf
    import numpy as np
    
    a = [[0.1, 0.2, 0.3], [1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3], [4.1, 4.2, 4.3]]
    a = np.asarray(a)
    idx1 = tf.Variable([0, 2, 3, 1], tf.int32)
    idx2 = tf.Variable([[0, 2, 3, 1], [4, 0, 2, 2]], tf.int32)
    out1 = tf.nn.embedding_lookup(a, idx1)
    out2 = tf.nn.embedding_lookup(a, idx2)
    init = tf.global_variables_initializer()
    
    with tf.Session() as sess:
        sess.run(init)
        print sess.run(out1)
        print out1
        print '=================='
        print sess.run(out2)
        print out2

    输出:

    [[ 0.1  0.2  0.3]
     [ 2.1  2.2  2.3]
     [ 3.1  3.2  3.3]
     [ 1.1  1.2  1.3]]
    Tensor("embedding_lookup:0", shape=(4, 3), dtype=float64)
    ==================
    [[[ 0.1  0.2  0.3]
      [ 2.1  2.2  2.3]
      [ 3.1  3.2  3.3]
      [ 1.1  1.2  1.3]]
    
     [[ 4.1  4.2  4.3]
      [ 0.1  0.2  0.3]
      [ 2.1  2.2  2.3]
      [ 2.1  2.2  2.3]]]
    Tensor("embedding_lookup_1:0", shape=(2, 4, 3), dtype=float64)
    参考链接:https://www.jianshu.com/p/ad88a0afa98f
    
  • 相关阅读:
    后台性能测试不可不知的二三事
    linux下操作mysql
    loadrunner scripts
    反射
    java 读取json
    java 多线程
    python_day11
    python爬虫
    python_day10
    python_day9
  • 原文地址:https://www.cnblogs.com/gaofighting/p/9625868.html
Copyright © 2011-2022 走看看