zoukankan      html  css  js  c++  java
  • tensorflow

    size

    Tensor 的 大小,长 * 宽;

    tf.size 返回 Tensor,需要 session;

    d1 = tf.random_uniform((3, 2))
    # print(d1.size)    # AttributeError: 'Tensor' object has no attribute 'size'
    size = tf.size(d1)
    sess = tf.Session()
    print(sess.run(size))       # 6

    shape 和 tf.shape 和 get_shape  set_shape

     

    先说结论再看例子

    1. 运行环境不同

    • shape 和 get_shape 返回元组,故无需 Session,可直接获取;
    • 而 tf.shape 返回 Tensor,需要 Session            【只有返回 Tensor 才需要 Session】

    2. 适用对象不同

    • tf.shape 适用于 Tensor,还有 ndarray,list;
    • shape 适用于 Tensor,还有 ndarray;
    • get_shape 只适用于 Tensor;

    代码如下

    ########## tf.shape ##########
    ### 用函数获取,返回 Tensor
    # 针对所有 Tensor,包括 Variable,array、list 也可以
    d5 = tf.shape(tf.random_normal((2, 3)))     ### Tensor
    print(d5)       # Tensor("Shape:0", shape=(2,), dtype=int32)
    d6 = tf.shape(tf.Variable([1. ,2.]))        ### Variable
    n3 = tf.shape(np.array([[1, 2], [3, 4]]))   ### ndarray
    n4 = tf.shape([1, 2])                       ### list
    with tf.Session() as sess1:
        print(sess1.run(d5))        # [2 3]
        print(sess1.run(d6))        # [2]
        print(sess1.run(n3))        # [2 2]
        print(sess1.run(n4))        # [2]
        
    
    ########## shape ##########
    ### 直接获取,返回元组
    # 针对所有 Tensor,包括 Variable,array 也可以
    d1 = tf.random_uniform((3, 2)).shape        ### Tensor
    print(d1)       # (3, 2)
    d2 = tf.Variable([1. ,2.]).shape            ### Variable
    print(d2)       # (2,)
    n1 = np.array([[1, 2], [3, 4]]).shape       ### ndarray
    print(n1)       # (2, 2)
    
    
    ########## get_shape ##########
    ### 直接获取,返回元组
    # 针对所有 Tensor,包括 Variable,不包括 array
    d3 = tf.random_uniform((3, 2)).get_shape()  ### Tensor
    print(d3)       # (3, 2)
    d4 = tf.Variable([1. ,2.]).get_shape()      ### Variable
    print(d4)       # (2,)
    # n2 = np.array([[1, 2], [3, 4]]).get_shape()     ### 报错 AttributeError: 'numpy.ndarray' object has no attribute 'get_shape'

    set_shape 是为一个 Tensor reset shape

    一般只用于设置 placeholder 的尺寸;

    x1 = tf.placeholder(tf.int32)
    x1.set_shape([2, 2])
    print(tf.shape(x1))
    
    sess = tf.Session()
    # print(sess.run(tf.shape(x1), feed_dict={x1:[[0,1,2,3]]}))       ### ValueError: Cannot feed value of shape (1, 4) for Tensor 'Placeholder:0', which has shape '(2, 2)'
    print(sess.run(tf.shape(x1), feed_dict={x1: [[0, 1], [2, 3]]}))

    限制 x1 只能是 (2,2) 的 shape;

    tf.squeeze 和 tf.expand_dims

    def squeeze(input, axis=None, name=None, squeeze_dims=None)

    压缩维度,如果被压缩的维度为 1 维,就去掉该维度,如果该维度不是 1 维,报错

    # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
      tf.shape(tf.squeeze(t))  # [2, 3]
    
    # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
      tf.shape(tf.squeeze(t, [2, 4]))  # [1, 2, 3, 1]

    axis 和 squeeze_dims 是一个意思, squeeze_dims 已被废弃;

    axis 可取 list 指定多个维度;

    c1 = tf.constant([[1, 3]])
    print(c1.shape)         # (1, 2)
    
    ### 第 0 维 的维度为 1
    c2 = tf.squeeze(c1, squeeze_dims=0)
    print(c2.shape)         # (2,)
    
    c3 = tf.squeeze(c1, axis=0)
    print(c3.shape)         # (2,)

    def expand_dims(input, axis=None, name=None, dim=None)

    在指定维度上增加 1 个维度

    axis 和 dim 是一个意思,dim 已被废弃

    data = tf.constant([[1, 2],
                        [3, 4]])
    print(data.shape)   # (2, 2)            ### 两个维度
    
    data2 = tf.expand_dims(data, dim=1)     ### 在第一个维度上添加一个维度
    print(data2.shape)  # (2, 1, 2)
    
    data3 = tf.expand_dims(data, dim=0)     ### 在第0个维度上添加一个维度
    print(data3.shape)  # (1, 2, 2)
    
    data4 = tf.expand_dims(data, dim=-1)    ### 在最后一个维度上添加一个维度
    print(data4.shape)  # (2, 2, 1)

    tf.concat

    按指定维度进行拼接 

    def concat(values, axis, name="concat") 

    axis 0 表示按列拼接,1 表示按行拼接

    d1 = tf.zeros((2, 3))
    d2 = tf.ones((2, 4))
    
    d3 = tf.concat([d1, d2], axis=1)        # 第 1 个维度
    d4 = tf.concat([d1, d2], axis=-1)       # -1 代表最后一个维度
    
    sess = tf.Session()
    sess.run(d3)
    sess.run(d4)
    print(d3.shape)     # (2, 7)
    print(d4.shape)     # (2, 7)

    参考资料:

    https://blog.csdn.net/m0_37744293/article/details/78254691  tf.shape()与tf.get_shape()

  • 相关阅读:
    QT QT程序初练
    Linux Shell编程三
    Linux Shell编程二
    git操作
    git push命令
    Zabbix全方位告警接入-电话/微信/短信都支持
    CentOS 7安装TigerVNC Server
    MySQL各版本的区别
    MariaDB Galera Cluster 部署
    MySQL高可用方案MHA的部署和原理
  • 原文地址:https://www.cnblogs.com/yanshw/p/12372314.html
Copyright © 2011-2022 走看看