zoukankan      html  css  js  c++  java
  • tensorflow

    size

    Tensor 的 大小,长 * 宽;

    tf.size 返回 Tensor,需要 session;

    d1 = tf.random_uniform((3, 2))
    # print(d1.size)    # AttributeError: 'Tensor' object has no attribute 'size'
    size = tf.size(d1)
    sess = tf.Session()
    print(sess.run(size))       # 6

    shape 和 tf.shape 和 get_shape  set_shape

     

    先说结论再看例子

    1. 运行环境不同

    • shape 和 get_shape 返回元组,故无需 Session,可直接获取;
    • 而 tf.shape 返回 Tensor,需要 Session            【只有返回 Tensor 才需要 Session】

    2. 适用对象不同

    • tf.shape 适用于 Tensor,还有 ndarray,list;
    • shape 适用于 Tensor,还有 ndarray;
    • get_shape 只适用于 Tensor;

    代码如下

    ########## tf.shape ##########
    ### 用函数获取,返回 Tensor
    # 针对所有 Tensor,包括 Variable,array、list 也可以
    d5 = tf.shape(tf.random_normal((2, 3)))     ### Tensor
    print(d5)       # Tensor("Shape:0", shape=(2,), dtype=int32)
    d6 = tf.shape(tf.Variable([1. ,2.]))        ### Variable
    n3 = tf.shape(np.array([[1, 2], [3, 4]]))   ### ndarray
    n4 = tf.shape([1, 2])                       ### list
    with tf.Session() as sess1:
        print(sess1.run(d5))        # [2 3]
        print(sess1.run(d6))        # [2]
        print(sess1.run(n3))        # [2 2]
        print(sess1.run(n4))        # [2]
        
    
    ########## shape ##########
    ### 直接获取,返回元组
    # 针对所有 Tensor,包括 Variable,array 也可以
    d1 = tf.random_uniform((3, 2)).shape        ### Tensor
    print(d1)       # (3, 2)
    d2 = tf.Variable([1. ,2.]).shape            ### Variable
    print(d2)       # (2,)
    n1 = np.array([[1, 2], [3, 4]]).shape       ### ndarray
    print(n1)       # (2, 2)
    
    
    ########## get_shape ##########
    ### 直接获取,返回元组
    # 针对所有 Tensor,包括 Variable,不包括 array
    d3 = tf.random_uniform((3, 2)).get_shape()  ### Tensor
    print(d3)       # (3, 2)
    d4 = tf.Variable([1. ,2.]).get_shape()      ### Variable
    print(d4)       # (2,)
    # n2 = np.array([[1, 2], [3, 4]]).get_shape()     ### 报错 AttributeError: 'numpy.ndarray' object has no attribute 'get_shape'

    set_shape 是为一个 Tensor reset shape

    一般只用于设置 placeholder 的尺寸;

    x1 = tf.placeholder(tf.int32)
    x1.set_shape([2, 2])
    print(tf.shape(x1))
    
    sess = tf.Session()
    # print(sess.run(tf.shape(x1), feed_dict={x1:[[0,1,2,3]]}))       ### ValueError: Cannot feed value of shape (1, 4) for Tensor 'Placeholder:0', which has shape '(2, 2)'
    print(sess.run(tf.shape(x1), feed_dict={x1: [[0, 1], [2, 3]]}))

    限制 x1 只能是 (2,2) 的 shape;

    tf.squeeze 和 tf.expand_dims

    def squeeze(input, axis=None, name=None, squeeze_dims=None)

    压缩维度,如果被压缩的维度为 1 维,就去掉该维度,如果该维度不是 1 维,报错

    # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
      tf.shape(tf.squeeze(t))  # [2, 3]
    
    # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
      tf.shape(tf.squeeze(t, [2, 4]))  # [1, 2, 3, 1]

    axis 和 squeeze_dims 是一个意思, squeeze_dims 已被废弃;

    axis 可取 list 指定多个维度;

    c1 = tf.constant([[1, 3]])
    print(c1.shape)         # (1, 2)
    
    ### 第 0 维 的维度为 1
    c2 = tf.squeeze(c1, squeeze_dims=0)
    print(c2.shape)         # (2,)
    
    c3 = tf.squeeze(c1, axis=0)
    print(c3.shape)         # (2,)

    def expand_dims(input, axis=None, name=None, dim=None)

    在指定维度上增加 1 个维度

    axis 和 dim 是一个意思,dim 已被废弃

    data = tf.constant([[1, 2],
                        [3, 4]])
    print(data.shape)   # (2, 2)            ### 两个维度
    
    data2 = tf.expand_dims(data, dim=1)     ### 在第一个维度上添加一个维度
    print(data2.shape)  # (2, 1, 2)
    
    data3 = tf.expand_dims(data, dim=0)     ### 在第0个维度上添加一个维度
    print(data3.shape)  # (1, 2, 2)
    
    data4 = tf.expand_dims(data, dim=-1)    ### 在最后一个维度上添加一个维度
    print(data4.shape)  # (2, 2, 1)

    tf.concat

    按指定维度进行拼接 

    def concat(values, axis, name="concat") 

    axis 0 表示按列拼接,1 表示按行拼接

    d1 = tf.zeros((2, 3))
    d2 = tf.ones((2, 4))
    
    d3 = tf.concat([d1, d2], axis=1)        # 第 1 个维度
    d4 = tf.concat([d1, d2], axis=-1)       # -1 代表最后一个维度
    
    sess = tf.Session()
    sess.run(d3)
    sess.run(d4)
    print(d3.shape)     # (2, 7)
    print(d4.shape)     # (2, 7)

    参考资料:

    https://blog.csdn.net/m0_37744293/article/details/78254691  tf.shape()与tf.get_shape()

  • 相关阅读:
    Vue 2.x windows环境下安装
    VSCODE官网下载缓慢或下载失败 解决办法
    angular cli 降级
    Win10 VS2019 设置 以管理员身份运行
    XSHELL 连接 阿里云ECS实例
    Chrome浏览器跨域设置
    DBeaver 执行 mysql 多条语句报错
    DBeaver 连接MySql 8.0 报错 Public Key Retrieval is not allowed
    DBeaver 连接MySql 8.0报错 Unable to load authentication plugin 'caching_sha2_password'
    Linux系统分区
  • 原文地址:https://www.cnblogs.com/yanshw/p/12372314.html
Copyright © 2011-2022 走看看