zoukankan      html  css  js  c++  java
  • tensorflow expand_dims和squeeze

      有时我们会碰到升维或降维的需求,比如现在有一个图像样本,形状是 [height, width, channels],我们需要把它输入到已经训练好的模型中做分类,而模型定义的输入变量是一个batch,即形状为 [batch_size, height, width, channels],这时就需要升维了。tensorflow提供了一个方便的升维函数:expand_dims,参数定义如下:

      tf.expand_dims(input, axis=None, name=None, dim=None)

      参数说明:

      input:待升维的tensor

      axis:插入新维度的索引位置

      name:输出tensor名称

      dim: 一般不用

      

    import tensorflow as tf
    
    sess = tf.Session()
    
    t = tf.constant([1, 2, 3], dtype=tf.int32)
    
    t.get_shape()
    # TensorShape([Dimension(3)])
    
    tf.expand_dims(t, 0).get_shape()
    # TensorShape([Dimension(1), Dimension(3)])
    
    tf.expand_dims(t, 1).get_shape()
    # TensorShape([Dimension(3), Dimension(1)])

      squeeze正好执行相反的操作:删除大小是1的维度

      tf.squeeze(input, squeeze_dims=None, name=None)

      input:  待降维的张量

      sequeeze_dims: list[int]类型,表示需要删除的维度索引。默认为[],即删除所以大小为1的维度

    # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
    shape(squeeze(t)) ==> [2, 3]
    Or, to remove specific size 1 dimensions:
     
    # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
    shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]

      在处理tensor的时候合理使用这两个函数,能极大的提高效率。例如处理输入样本、执行向量与矩阵的点乘等情况。

    参考:https://blog.csdn.net/qq_31780525/article/details/72280284

      

  • 相关阅读:
    codechef: ADAROKS2 ,Ada Rooks 2
    codechef: BINARY, Binary Movements
    codechef : TREDEG , Trees and Degrees
    ●洛谷P1291 [SHOI2002]百事世界杯之旅
    ●BZOJ 1416 [NOI2006]神奇的口袋
    ●CodeForce 293E Close Vertices
    ●POJ 1741 Tree
    ●CodeForces 480E Parking Lot
    ●计蒜客 百度地图的实时路况
    ●CodeForces 549F Yura and Developers
  • 原文地址:https://www.cnblogs.com/estragon/p/9935148.html
Copyright © 2011-2022 走看看