zoukankan      html  css  js  c++  java
  • tensorflow expand_dims和squeeze

      有时我们会碰到升维或降维的需求,比如现在有一个图像样本,形状是 [height, width, channels],我们需要把它输入到已经训练好的模型中做分类,而模型定义的输入变量是一个batch,即形状为 [batch_size, height, width, channels],这时就需要升维了。tensorflow提供了一个方便的升维函数:expand_dims,参数定义如下:

      tf.expand_dims(input, axis=None, name=None, dim=None)

      参数说明:

      input:待升维的tensor

      axis:插入新维度的索引位置

      name:输出tensor名称

      dim: 一般不用

      

    import tensorflow as tf
    
    sess = tf.Session()
    
    t = tf.constant([1, 2, 3], dtype=tf.int32)
    
    t.get_shape()
    # TensorShape([Dimension(3)])
    
    tf.expand_dims(t, 0).get_shape()
    # TensorShape([Dimension(1), Dimension(3)])
    
    tf.expand_dims(t, 1).get_shape()
    # TensorShape([Dimension(3), Dimension(1)])

      squeeze正好执行相反的操作:删除大小是1的维度

      tf.squeeze(input, squeeze_dims=None, name=None)

      input:  待降维的张量

      sequeeze_dims: list[int]类型,表示需要删除的维度索引。默认为[],即删除所以大小为1的维度

    # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
    shape(squeeze(t)) ==> [2, 3]
    Or, to remove specific size 1 dimensions:
     
    # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
    shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]

      在处理tensor的时候合理使用这两个函数,能极大的提高效率。例如处理输入样本、执行向量与矩阵的点乘等情况。

    参考:https://blog.csdn.net/qq_31780525/article/details/72280284

      

  • 相关阅读:
    方法转换IE、Firefox、Chrome区别
    splice方法便签
    webstorm主题网址+使用方法
    从程序员到项目经理(一):没有捷径
    界面原型图绘制工具Pencil
    程序员:伤不起的三十岁
    从程序员到项目经理(三):认识项目经理
    从程序员到项目经理(二):如何胜任
    原型制作软件 Axure RP
    软件界面原型设计工具 UIDesigner
  • 原文地址:https://www.cnblogs.com/estragon/p/9935148.html
Copyright © 2011-2022 走看看