zoukankan      html  css  js  c++  java
  • keras中的mask操作

    使用背景

    最常见的一种情况, 在NLP问题的句子补全方法中, 按照一定的长度, 对句子进行填补和截取操作. 一般使用keras.preprocessing.sequence包中的pad_sequences方法, 在句子前面或者后面补0. 但是这些零是我们不需要的, 只是为了组成可以计算的结构才填补的. 因此计算过程中, 我们希望用mask的思想, 在计算中, 屏蔽这些填补0值得作用. keras中提供了mask相关的操作方法.

    原理

    在keras中, Tensor在各层之间传递, Layer对象接受的上层Layer得到的Tensor, 输出的经过处理后的Tensor.

    keras是用一个mask矩阵来参与到计算当中, 决定在计算中屏蔽哪些位置的值. 因此mask矩阵其中的值就是True/False, 其形状一般与对应的Tensor相同. 同样与Tensor相同的是, mask矩阵也会在每层Layer被处理, 得到传入到下一层的mask情况.

    使用方法

    1. 最直接的, 在NLP问题中, 对句子填补之后, 就要输入到Embedding层中, 将tokenid转换成对应的vector. 我们希望被填补的0值在后续的计算中不产生影响, 就可以在初始化Embedding层时指定参数mask_zeroTrue, 意思就是屏蔽0值, 即填补的0值.

      Embedding层中的compute_mask方法中, 会计算得到mask矩阵. 虽然在Embedding层中不会使用这个mask矩阵, 即0值还是会根据其对应的向量进行查找, 但是这个mask矩阵会被传入到下一层中, 如果下一层, 或之后的层会对mask进行考虑, 那就会起到对应的作用.

    2. 也可以在keras.layers包中引用Masking类, 使用mask_value指定固定的值被屏蔽. 在调用call方法时, 就会输出屏蔽后的结果.

      需要注意的是Masking这种层的compute_mask方法, 源码如下:

      def compute_mask(self, inputs, mask=None):
          output_mask = K.any(K.not_equal(inputs, self.mask_value), axis=-1)
          return output_mask
      

      可以看到, 这一层输出的mask矩阵, 是根据这层的输入得到的, 具体的说是会比输入第一个维度, 这是因为最后一个维度被K.any(axis=-1)给去掉了. 在使用时需要注意这种操作的意义以及维度的变化.

    自定义使用方法

    更多的, 我们还是在自定义的层中, 需要支持mask操作, 因此需要对应的逻辑.


    首先, 如果我们希望自定义的这个层支持mask操作, 就需要在__init__方法中指定:

    self.supports_masking = True
    

    如果在本层计算中需要使用到mask, 则call方法需要多传入一个mask 参数, 即:

    def call(self, inputs, mask=None):
        pass
    

    然后, 如果还要继续输出mask, 供之后的层使用, 如果不对mask矩阵进行变换, 这不用进行任何操作, 否则就需要实现compute_mask函数:

    def compute_mask(self, inputs, mask=None):
        pass
    

    这里的inputs就是输入的Tensor, 与call方法中接收到的一样, mask就是上层传入的mask矩阵.

    如果希望mask到此为止, 之后的层不再使用, 则该函数直接返回None即可:

    def compute_mask(self, inputs, mask=None):
        return None
    

    参考资料

    Keras自定义实现带masking的meanpooling层

    Keras实现支持masking的Flatten层

  • 相关阅读:
    Fluent NHibernate other example
    Fluent NHibernate example
    csharp:Chart
    csharp: Socket
    javascript:Bing Maps AJAX Control, Version 7.0
    csharp: NHibernate and Entity Framework (EF) (object-relational mapper)
    csharp:正则表达式采集网页数据
    ASP.NET AJAX Control Toolkit
    算法习题---5.11邮件传输代理的交互(Uva814)
    STM32---喜提点灯
  • 原文地址:https://www.cnblogs.com/databingo/p/9339175.html
Copyright © 2011-2022 走看看