zoukankan      html  css  js  c++  java
  • tf.expand_dims()和tf.squeeze()

    1.tf.expand_dims()

    tf.expand_dims(input, axis=None, name=None, dim=None)

    作用:给定张量,输入形状的维度索引轴处插入1的尺寸。 尺寸索引轴从零开始; 如果为指定的轴为负数,则从末尾开始算起。

    参数:

    1. input:张量。
    2. aixs:0-D(标量),指定扩大输入形状的维度索引。
    3. name:输出名称Tensor。
    4. dim:0-D(标量), 等同于轴,不推荐使用。

    返回:具有与输入相同数据的张量,但其形状添加了尺寸为1的附加尺寸。

      如果要将批次尺寸添加到单个元素,此操作很有用。 例如,如果您有一个形状为[[height,width,channels]`的图像,则可以将其与具有`expand_dims(image,0)`的1张图像一起批处理,这将使形状为[1,height ,width,channels]。

    # 't' is a tensor of shape [2]
    tf.shape(tf.expand_dims(t, 0))  # [1, 2]
    tf.shape(tf.expand_dims(t, 1))  # [2, 1]
    tf.shape(tf.expand_dims(t, -1))  # [2, 1]
    
    # 't2' is a tensor of shape [2, 3, 5]
    tf.shape(tf.expand_dims(t2, 0))  # [1, 2, 3, 5]
    tf.shape(tf.expand_dims(t2, 2))  # [2, 3, 1, 5]
    tf.shape(tf.expand_dims(t2, 3))  # [2, 3, 5, 1]
    ```
    
    This operation requires that:
    
    `-1-input.dims() <= dim <= input.dims()`
    
    This operation is related to `squeeze()`, which removes dimensions of
    size 1.
    
    Args:
      input: A `Tensor`.
      axis: 0-D (scalar). Specifies the dimension index at which to
        expand the shape of `input`. Must be in the range
        `[-rank(input) - 1, rank(input)]`.
      name: The name of the output `Tensor`.
      dim: 0-D (scalar). Equivalent to `axis`, to be deprecated.
    
    Returns:
      A `Tensor` with the same data as `input`, but its shape has an additional
      dimension of size 1 added.
    
    Raises:
      ValueError: if both `dim` and `axis` are specified.

    bert中源码:

    # 该函数默认输入的形状为【batch_size, seq_length, input_num】
    # 如果输入为2D的【batch_size, seq_length】,则扩展到【batch_size, seq_length, 1】
    if input_ids.shape.ndims == 2:
      input_ids = tf.expand_dims(input_ids, axis=[-1])

    2.tf.squeeze()

    tf.squeeze(input, squeeze_dims=None, name=None)
    

    作用:给定张量输入,此操作返回相同类型的张量,并删除所有尺寸为1的维度。 如果不想删除所有尺寸为1的维度,可以通过指定squeeze_dims来删除特定尺寸的维度。
    参数:

    1. input:要挤压的张量
    2. squeeze_dims:
      1. 可选的ints列表, 默认为[]。
      2. 如果指定,只能挤压列出的维度。
      3. 维度索引从0开始,挤压不是1的维度是一个错误
    3. name:操作的名称(可选)

    返回:与输入的类型相同。 包含与输入相同的数据,但具有一个或多个尺寸为1的维度被删除。

    举例:

    import tensorflow as tf
    sess = tf.InteractiveSession()
    t1 = tf.constant([1,2,3,4,5,6],shape=[1,2,3,1])
    print('t1')
    print(t1.eval())
    t2 = tf.squeeze(t1)
    print('t2')
    print(t2.eval())
    t3 = tf.squeeze(t1,[3])
    print('t3')
    print(t3.eval())

    参考文献:

    【1】tensorflow 笔记14:tf.expand_dims和tf.squeeze函数 - 细雨微光 - 博客园

  • 相关阅读:
    小结css2与css3的区别
    javascript变量的作用域
    javascript面向对象
    小结php中几种网页跳转
    foreach
    post与get,这两人到底神马区别??
    typescript遍历Map
    dataTable.js参数
    showModal()和show()的区别
    javascript中location.protocol、location.hostname和location.port
  • 原文地址:https://www.cnblogs.com/nxf-rabbit75/p/12095669.html
Copyright © 2011-2022 走看看