zoukankan      html  css  js  c++  java
  • tf.concat()函数解析(最清晰的解释)

    欢迎关注WX公众号:【程序员管小亮】

    tf.concat()函数用于数组或者矩阵拼接。

    tf.concat的官方解释

    tf.concat(
        values,
        axis,
        name='concat'
    )
    

    其中:
    values应该是一个tensor的list或者tuple,里面是准备连接的矩阵或者数组。

    axis则是我们准备连接的矩阵或者数组的维度。

    • axis=0代表在第0个维度拼接
    • axis=1代表在第1个维度拼接
    • axis=-1表示在倒数第1个维度拼接

    负数在数组索引里面表示倒数,也就算是倒着数,-1是最后一个,-2是倒数第二个,对于二维矩阵拼接来说,axis=-1等价于axis=1。一般在维度非常高的情况下,if 我们想在最’高’的维度进行拼接,一般就直接用倒数机制,直接axis=-1就搞定了。

    1. values:

    import tensorflow as tf
    
    t1=tf.constant([1,2,3])
    t2=tf.constant([4,5,6])
    print(t1)
    print(t2)
    concated = tf.concat([t1,t2], 1)
    
    > Tensor("Const_20:0", shape=(3,), dtype=int32)
    > Tensor("Const_21:0", shape=(3,), dtype=int32)
    > ValueError: Shapes (2, 3) and () are incompatible 
    

    因为它们对应的shape只有一个维度,当然不能在第二维上拼接了,虽然实际中两个向量可以在行上(axis = 1)拼接,但是放在程序里是会报错的

    import tensorflow as tf
    
    t1=tf.expand_dims(tf.constant([1,2,3]),1)
    t2=tf.expand_dims(tf.constant([4,5,6]),1)
    print(t1)
    print(t2)
    concated = tf.concat([t1,t2], 1)
    
    > Tensor("ExpandDims_26:0", shape=(3, 1), dtype=int32)
    > Tensor("ExpandDims_27:0", shape=(3, 1), dtype=int32)
    

    如果想要拼接,必须要调用tf.expand_dims()来扩维:

    2. axis:

    第0个维度代表最外面的括号所在的维度,第1个维度代表最外层括号里面的那层括号所在的维度,以此类推。

    import tensorflow as tf
    
    with tf.Session() as sess:
    	t1 = [[1, 2, 3],  [4, 5, 6]]
    	t2 = [[7, 8, 9],  [10, 11, 12]]
    	print(sess.run(tf.concat([t1, t2], 0)))  
    >  [[ 1  2  3]    [ 4  5  6]   [ 7  8  9]    [10 11 12]]
    
    	print(sess.run(tf.concat([t1, t2], 1)))  
    >  [[ 1  2  3  7  8  9]    [ 4  5  6 10 11 12]]
    
    	print(sess.run(tf.concat([t1, t2], -1)))  
    >  [[ 1  2  3  7  8  9]    [ 4  5  6 10 11 12]]
    

    如果不使用sess.run()运行,就会出现下面的情况:

    import tensorflow as tf
    
    t1 = [[1, 2, 3],  [4, 5, 6]]
    t2 = [[7, 8, 9],  [10, 11, 12]]
    w1 = print(tf.concat([t1, t2], 0))
    > Tensor("concat_29:0", shape=(4, 3), dtype=int32)
    
    w2 = print(tf.concat([t1, t2], 1)
    > Tensor("concat_30:0", shape=(2, 6), dtype=int32)
    
    w3 = print(tf.concat([t1, t2], -1))
    > Tensor("concat_31:0", shape=(2, 6), dtype=int32)
    

    python课程推荐。
    在这里插入图片描述

  • 相关阅读:
    Java面试题:栈和队列的实现
    Java面试题:如何对HashMap按键值排序
    经典的Java基础面试题集锦
    9个Java初始化和回收的面试题
    20个高级Java面试题汇总
    Spring、Spring MVC、MyBatis整合文件配置详解2
    Spring、Spring MVC、MyBatis整合文件配置详解
    Spring:基于注解的Spring MVC
    margin百分比的相对值--宽度!
    jquery.cxSelect插件,城市没单位
  • 原文地址:https://www.cnblogs.com/hzcya1995/p/13302857.html
Copyright © 2011-2022 走看看