zoukankan      html  css  js  c++  java
  • tf.concat( )和tf.stack( )

    相同点:都是组合重构数据.

    不同点:concat()不改变维数,而stack改变了维数(待定!!!)

    tf.concat是连接两个矩阵的操作,请注意API版本更改问题,相应参数也发生改变,具体查看API.

    tf.concat(concat_dim, values, name='concat')

    除去name参数用以指定该操作的name,与方法有关的一共两个参数:

    第一个参数concat_dim:必须是一个数,表明在哪一维上连接

         如果concat_dim是0,那么在某一个shape的第一个维度上连,对应到实际,就是叠放到列上 

    1. t1 = [[1, 2, 3], [4, 5, 6]]  
    2. t2 = [[7, 8, 9], [10, 11, 12]]  
    3. tf.concat(0, [t1, t2]) == > [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]  
    t1 = [[1, 2, 3], [4, 5, 6]]
    t2 = [[7, 8, 9], [10, 11, 12]]
    tf.concat(0, [t1, t2]) == > [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]

                 如果concat_dim是1,那么在某一个shape的第二个维度上连

    1. t1 = [[1, 2, 3], [4, 5, 6]]  
    2. t2 = [[7, 8, 9], [10, 11, 12]]  
    3. tf.concat(1, [t1, t2]) ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12  
    t1 = [[1, 2, 3], [4, 5, 6]]
    t2 = [[7, 8, 9], [10, 11, 12]]
    tf.concat(1, [t1, t2]) ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12

                 如果有更高维,最后连接的依然是指定那个维:

                 values[i].shape = [D0, D1, ... Dconcat_dim(i), ...Dn]连接后就是:[D0, D1, ... Rconcat_dim, ...Dn]

      

    1. # tensor t3 with shape [2, 3]  
    2. # tensor t4 with shape [2, 3]  
    3. tf.shape(tf.concat(0, [t3, t4])) ==> [4, 3]  
    4. tf.shape(tf.concat(1, [t3, t4])) ==> [2, 6]  
    # tensor t3 with shape [2, 3]
    # tensor t4 with shape [2, 3]
    tf.shape(tf.concat(0, [t3, t4])) ==> [4, 3]
    tf.shape(tf.concat(1, [t3, t4])) ==> [2, 6]

    第二个参数values:就是两个或者一组待连接的tensor了

     

    /×××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××××/

    这里要注意的是:如果是两个向量,它们是无法调用  

     
    1. tf.concat(1, [t1, t2])  
    tf.concat(1, [t1, t2])

    来连接的,因为它们对应的shape只有一个维度,当然不能在第二维上连了,虽然实际中两个向量可以在行上连,但是放在程序里是会报错的

    如果要连,必须要调用tf.expand_dims来扩维: 

    1. t1=tf.constant([1,2,3])  
    2. t2=tf.constant([4,5,6])  
    3. #concated = tf.concat(1, [t1,t2])这样会报错  
    4. t1=tf.expand_dims(tf.constant([1,2,3]),1)  
    5. t2=tf.expand_dims(tf.constant([4,5,6]),1)  
    6. concated = tf.concat(1, [t1,t2])#这样就是正确的  
  • 相关阅读:
    部署yearning1.3
    git常用指令
    U盘centos7系统安装http://www.augsky.com/599.html
    C语言与SQL SERVER数据库(转)
    C连接MySQL数据库开发之Windows环境配置及测试(转)
    vs2012中添加lib,.h文件方法(原)
    如何用Visual Studio 2013 (vs2013)编写C语言程序 (转)
    Java值传递以及引用的传递、数组的传递!!
    ssh整合需要那些jar
    类加载器
  • 原文地址:https://www.cnblogs.com/mdumpling/p/8053474.html
Copyright © 2011-2022 走看看