zoukankan      html  css  js  c++  java
  • tensorflow expand_dims和squeeze

      有时我们会碰到升维或降维的需求,比如现在有一个图像样本,形状是 [height, width, channels],我们需要把它输入到已经训练好的模型中做分类,而模型定义的输入变量是一个batch,即形状为 [batch_size, height, width, channels],这时就需要升维了。tensorflow提供了一个方便的升维函数:expand_dims,参数定义如下:

      tf.expand_dims(input, axis=None, name=None, dim=None)

      参数说明:

      input:待升维的tensor

      axis:插入新维度的索引位置

      name:输出tensor名称

      dim: 一般不用

      

    import tensorflow as tf
    
    sess = tf.Session()
    
    t = tf.constant([1, 2, 3], dtype=tf.int32)
    
    t.get_shape()
    # TensorShape([Dimension(3)])
    
    tf.expand_dims(t, 0).get_shape()
    # TensorShape([Dimension(1), Dimension(3)])
    
    tf.expand_dims(t, 1).get_shape()
    # TensorShape([Dimension(3), Dimension(1)])

      squeeze正好执行相反的操作:删除大小是1的维度

      tf.squeeze(input, squeeze_dims=None, name=None)

      input:  待降维的张量

      sequeeze_dims: list[int]类型,表示需要删除的维度索引。默认为[],即删除所以大小为1的维度

    # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
    shape(squeeze(t)) ==> [2, 3]
    Or, to remove specific size 1 dimensions:
     
    # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
    shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]

      在处理tensor的时候合理使用这两个函数,能极大的提高效率。例如处理输入样本、执行向量与矩阵的点乘等情况。

    参考:https://blog.csdn.net/qq_31780525/article/details/72280284

      

  • 相关阅读:
    牛客练习赛53 A-E
    算导第二章笔记 (归并排序 之 插入排序优化)
    LightOJ 1372 (枚举 + 树状数组)
    LightOJ 1348 (树链剖分 + 线段树(树状数组))
    Light OJ 1343
    Light OJ 1266
    Light OJ 1085
    CodeForces 671C
    Codeforces Round #352 (Div. 2) (A-D)
    ZOJ1008
  • 原文地址:https://www.cnblogs.com/estragon/p/9935148.html
Copyright © 2011-2022 走看看