zoukankan      html  css  js  c++  java
  • 解决keras.backend.reshape中的错误ValueError: Tried to convert 'shape' to a tensor and failed. Error: Cannot convert a partially known TensorShape to a Tensor

    许多CNN网络都有Fusion layer作为融合层,比如:

     参考:https://arxiv.org/pdf/1712.03400.pdf

    相关代码:(https://github.com/baldassarreFe/deep-koalarization/blob/master/src/koalarization/fusion_layer.py)

    class FusionLayer(Layer):
        def call(self, inputs, mask=None):
            imgs, embs = inputs
            reshaped_shape = imgs.shape[:3].concatenate(embs.shape[1])
            embs = K.repeat(embs, imgs.shape[1] * imgs.shape[2])
            embs = K.reshape(embs, reshaped_shape)
            return K.concatenate([imgs, embs], axis=3)

    当我实际去做的时候, K.reshape 报错:ValueError: Tried to convert 'shape' to a tensor and failed. Error: Cannot convert a partially known TensorShape to a Tensor

     reshaped_shape = enco_loco.shape[:3].concatenate(enco_glob.shape[1])
        fuse = K.repeat(enco_glob, enco_loco.shape[1]*enco_loco.shape[2])
        fuse = K.reshape(fuse, (reshaped_shape))
        fuse = K.concatenate([enco_loco, fuse], axis=3)

    相关信息:

    enco_loco:(None, 16, 16, 512)
    enco_glob:(None, 512)
    reshaped_shape:(None, 16, 16, 512)
    enco_glob.shape:(None, 512)
    fuse.shape:(None, 256, 512)

    最后想把fuse从(None, 256, 512) 变成(None, 16, 16, 512) 就出现上述错误。

    解决过程:

    fuse = K.reshape(fuse, (-1, reshaped_shape[1], reshaped_shape[2], reshaped_shape[3]))

    参考:https://github.com/matterport/Mask_RCNN/issues/1070

    但又出现错误:AttributeError 'NoneType' object has no attribute '_inbound_nodes'

    原来是因为:“只要使用Model,就必须保证该函数内全为layer而不能有其他函数,如果有其他函数必须用Lambda封装为layer。”

    参考:https://zhuanlan.zhihu.com/p/138075621

    好吧,再改一下:

    from keras.layers import  RepeatVector, Reshape
    from keras.layers.merge import concatenate
    
    reshaped_shape = enco_loco.shape[:3].concatenate(enco_glob.shape[1])
        fuse = RepeatVector(enco_loco.shape[1]*enco_loco.shape[2])(enco_glob)
        fuse = Reshape(( reshaped_shape[1], reshaped_shape[2], reshaped_shape[3]))(fuse)
        fuse = concatenate([enco_loco, fuse], axis=3)

    注意这里的维度必须是( reshaped_shape[1], reshaped_shape[2], reshaped_shape[3]) 而不是 ( -1, reshaped_shape[1], reshaped_shape[2], reshaped_shape[3])

    不然会出错

  • 相关阅读:
    Sql Server 邮件日志 操作 IT
    导出Excel IT
    Sqlserver 2005 修改数据库默认排序 IT
    SqlServer 备份数据库语法 IT
    HDFS常用shell命令
    改写UMFPACK算例中的压缩方式(动态)
    umFPACK使用调用(一)
    改写UMFPACK算例中的压缩方式(静态)
    利用C/C++实现从文件读入到子程序中调用返回结果
    改写UMFPACK算例中的压缩方式
  • 原文地址:https://www.cnblogs.com/mrlonely2018/p/13971791.html
Copyright © 2011-2022 走看看