zoukankan      html  css  js  c++  java
  • keras中激活函数自定义(以mish函数为列)

     若使用keras框架直接编辑函数调用会导致编译错误。因此,有2种方法可以实现keras的调用,其一使用lamda函数调用,

    其二使用继承Layer层调用(如下代码)。如果使用继承layer层调用,那你可以将你想要实现的方式,通过call函数编辑,而

    call函数的参数一般为输入特征变量[batch,h,w,c],具体实现如下代码:

    class Mish(Layer):
    '''
    Mish Activation Function.
    .. math::
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x}))
    tanh=(1 - e^{-2x})/(1 + e^{-2x})
    Shape:
    - Input: Arbitrary. Use the keyword argument `input_shape`
    (tuple of integers, does not include the samples axis)
    when using this layer as the first layer in a model.
    - Output: Same shape as the input.
    Examples:
    >>> X_input = Input(input_shape)
    >>> X = Mish()(X_input)
    '''

    def __init__(self, **kwargs):
    super(Mish, self).__init__(**kwargs)
    self.supports_masking = True

    def call(self, inputs):
    return inputs * K.tanh(K.softplus(inputs))

    def get_config(self):
    config = super(Mish, self).get_config()
    return config

    def compute_output_shape(self, input_shape):
    '''
    compute_output_shape(self, input_shape):为了能让Keras内部shape的匹配检查通过,
    这里需要重写compute_output_shape方法去覆盖父类中的同名方法,来保证输出shape是正确的。
    父类Layer中的compute_output_shape方法直接返回的是input_shape这明显是不对的,
    所以需要我们重写这个方法。所以这个方法也是4个要实现的基本方法之一。
    '''
    return input_shape





    有了mish激活函数,该如何调呢?以下代码简单演示其调用方式:

    cov1=conv2d(卷积参数)(input) # 将输入input进行卷积操作
    Mish()(cov1)  # 将卷积结果输入定义的激活类中,实现mish激活







  • 相关阅读:
    如何在一个for语句中迭代多个对象(2.7)
    yield列表反转 islice切片(2.6)
    yield和生成器, 通过斐波那契数列学习(2.5)
    python实现线程池(2.4)
    LOJ 3120: 洛谷 P5401: 「CTS2019 | CTSC2019」珍珠
    瞎写的理性愉悦:正整数幂和与伯努利数
    bzoj 3328: PYXFIB
    LOJ 3119: 洛谷 P5400: 「CTS2019 | CTSC2019」随机立方体
    洛谷 P5345: 【XR-1】快乐肥宅
    LOJ 3089: 洛谷 P5319: 「BJOI2019」奥术神杖
  • 原文地址:https://www.cnblogs.com/tangjunjun/p/12839457.html
Copyright © 2011-2022 走看看