https://blog.csdn.net/qq_40652148/article/details/80467131
https://yq.aliyun.com/articles/602111
git 代码:
https://blog.csdn.net/CodeMaster_/article/details/76223835
https://github.com/TracyMcgrady6/Distribute_MNIST/blob/master/distributed.py
tf.nn.embedding_lookup函数解释:
https://www.jianshu.com/p/91a5de231b90
举个例子:
import numpy as np
import tensorflow as tf
data = np.array([[[2],[1]],[[3],[4]],[[6],[7]]])
data = tf.convert_to_tensor(data)
lk = [[0,1],[1,0],[0,0]]
lookup_data = tf.nn.embedding_lookup(data,lk)
init = tf.global_variables_initializer()
先让我们看下不同数据对应的维度:
In [76]: data.shape
Out[76]: (3, 2, 1)
In [77]: np.array(lk).shape
Out[77]: (3, 2)
In [78]: lookup_data
Out[78]: <tf.Tensor 'embedding_lookup_8:0' shape=(3, 2, 2, 1) dtype=int64>
这个是怎么做到的呢?关键的部分来了,看下图:
lk中的值,在要寻找的embedding数据中下找对应的index下的vector进行拼接。永远是look(lk)部分的维度+embedding(data)部分的除了第一维后的维度拼接。很明显,我们也可以得到,lk里面值是必须要小于等于embedding(data)的最大维度减一的。