gather就是按行取值:
a1 = [[1,2], [3, 4], [5, 6]] a2 = tf.gather(tf.constant(a1), [0, 1]) print(a2)
输出:
tf.Tensor( [[1 2] [3 4]], shape=(2, 2), dtype=int32)
相当于:
a1[:2]