zoukankan      html  css  js  c++  java
  • tensorflow 笔记7:tf.concat 和 ops中的array_ops.concat

    用于连接两个矩阵:

    mn = array_ops.concat([a, d], 1) #  按照第二维度相接,shape1 [m,a] shape2 [m,b] ,concat_done shape : [m,a+b]

    tensorflow Rnn,Lstm,Gru,源码中是用以上的函数来链接Xt 和 Ht-1 的,两者的shape 分别是【batch_size, emb_size】【batch_size,Hidden_size】

    连接接后为的shape为:【batch_size,embedding_size + Hidden_size】,作为当前时刻的输入;

    测试代码:

     1 import os
     2 import tensorflow as tf
     3 import numpy as np
     4 import sys
     5 from tensorflow.python.ops import array_ops
     6 #array_ops.concat([inputs, state], 1)
     7 
     8 a = tf.constant([[1,12,8,6], [3,4,6,7]])  # shape [2,4]
     9 b = tf.constant([[10, 20,6,88], [30,40,7,8]]) # shape [2,4]
    10 c = tf.constant([[10, 20,6,88,99], [30,40,7,8,15]]) #shape [2,5]
    11 d = tf.constant([[10,20,6,88], [30,40,7,8],[30,40,7,8]]) # shape [3,4]
    12 nn = tf.concat([a, d],0) # 按照第一维度相接,shape1 [a,m] shape2 [b,m] concat_done:[a+b,m]
    13 nn_1 = tf.concat([a, c],1) # 按照第二维度相接,shape1 [m,a] shape2 [m,b] concat_done:[m,a+b]
    14 mn = array_ops.concat([a, d], 0) # 按照第一维度相接,shape1 [a,m] shape2 [b,m] concat_done:[a+b,m]
    15 mn_1 = array_ops.concat([a, c], 1) # 按照第二维度相接,shape1 [m,a] shape2 [m,b] concat_done:[m,a+b]
    16 
    17 with tf.Session() as sess:
    18      print (nn)
    19      print (nn.eval())
    20      print (nn_1)
    21      print (nn_1.eval())
    22      print (mn)
    23      print (mn.eval())   # shape [5,4]
    24      print (mn_1)
    25      print (mn_1.eval()) # shape [2,9]

    结果输出:

    Tensor("concat:0", shape=(5, 4), dtype=int32)
    [[ 1 12  8  6]
     [ 3  4  6  7]
     [10 20  6 88]
     [30 40  7  8]
     [30 40  7  8]]
    Tensor("concat_1:0", shape=(2, 9), dtype=int32)
    [[ 1 12  8  6 10 20  6 88 99]
     [ 3  4  6  7 30 40  7  8 15]]
    Tensor("concat_2:0", shape=(5, 4), dtype=int32)
    [[ 1 12  8  6]
     [ 3  4  6  7]
     [10 20  6 88]
     [30 40  7  8]
     [30 40  7  8]]
    Tensor("concat_3:0", shape=(2, 9), dtype=int32)
    [[ 1 12  8  6 10 20  6 88 99]
     [ 3  4  6  7 30 40  7  8 15]]

  • 相关阅读:
    错题
    URL和URI区别
    适配器
    JAVA 反射机制
    JAVA 面试题
    JAVA 继承
    多态 JAVA
    Java面向对象编辑
    [LeetCode] Merge k Sorted Lists
    [LeetCode] Valid Palindrome
  • 原文地址:https://www.cnblogs.com/lovychen/p/9367099.html
Copyright © 2011-2022 走看看