zoukankan      html  css  js  c++  java
  • tensorflow2.0 squeeze出错

    用tf.keras写了自定义层,但在调用自定义层的时候总是报错,找了好久才发现问题所在,所以记下此问题。

    问题代码

    u=tf.squeeze(tf.expand_dims(tf.expand_dims(inputs,axis=1),axis=3)@self.kernel,axis=3)
    

    其中inputs的第一维为None,这里的代码为自定义的前向传播。我是想将得到的输出张量维度为1的维度删掉,因此调用了tf.squeeze方法,这时虽然没有报错但出现了问题。我分别打印了下面内容。

    print(tf.expand_dims(tf.expand_dims(inputs,axis=1),axis=3).shape)
    print(self.kernel.shape)
    print((tf.expand_dims(tf.expand_dims(inputs,axis=1),axis=3)@self.kernel).shape)
    print(tf.squeeze(tf.expand_dims(tf.expand_dims(inputs,axis=1))@self.kernel,axis=3))
    

    可以发现,当张量第一维为None的时候tf.squeeze使结果变为了0。我想要的结果是删除第三个输出的大小为1的维度,即得到下面的结果

    解决使用tf.squeeze的时候加上删除的维度。

    tf.squeeze(tf.expand_dims(tf.expand_dims(inputs,axis=1),axis=3)@self.kernel,axis=3)
    
  • 相关阅读:
    vue通过input选取图片,jq的ajax向服务器上传img
    IDEA常用快捷键
    JavaScript预解析
    jQuery实现颜色打字机
    MVC超链接调用控制器内的方法
    jQuery实现鼠标移入切换图片
    聚类算法
    并行K-Means
    [Err] 1055
    地图匹配实践
  • 原文地址:https://www.cnblogs.com/lolybj/p/11581917.html
Copyright © 2011-2022 走看看