zoukankan      html  css  js  c++  java
  • keras中激活函数自定义(以mish函数为列)

     若使用keras框架直接编辑函数调用会导致编译错误。因此,有2种方法可以实现keras的调用,其一使用lamda函数调用,

    其二使用继承Layer层调用(如下代码)。如果使用继承layer层调用,那你可以将你想要实现的方式,通过call函数编辑,而

    call函数的参数一般为输入特征变量[batch,h,w,c],具体实现如下代码:

    class Mish(Layer):
    '''
    Mish Activation Function.
    .. math::
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x}))
    tanh=(1 - e^{-2x})/(1 + e^{-2x})
    Shape:
    - Input: Arbitrary. Use the keyword argument `input_shape`
    (tuple of integers, does not include the samples axis)
    when using this layer as the first layer in a model.
    - Output: Same shape as the input.
    Examples:
    >>> X_input = Input(input_shape)
    >>> X = Mish()(X_input)
    '''

    def __init__(self, **kwargs):
    super(Mish, self).__init__(**kwargs)
    self.supports_masking = True

    def call(self, inputs):
    return inputs * K.tanh(K.softplus(inputs))

    def get_config(self):
    config = super(Mish, self).get_config()
    return config

    def compute_output_shape(self, input_shape):
    '''
    compute_output_shape(self, input_shape):为了能让Keras内部shape的匹配检查通过,
    这里需要重写compute_output_shape方法去覆盖父类中的同名方法,来保证输出shape是正确的。
    父类Layer中的compute_output_shape方法直接返回的是input_shape这明显是不对的,
    所以需要我们重写这个方法。所以这个方法也是4个要实现的基本方法之一。
    '''
    return input_shape





    有了mish激活函数,该如何调呢?以下代码简单演示其调用方式:

    cov1=conv2d(卷积参数)(input) # 将输入input进行卷积操作
    Mish()(cov1)  # 将卷积结果输入定义的激活类中,实现mish激活







  • 相关阅读:
    postgres 类型转换 cast 转
    postgresql Delete+ join
    输出特定格式的查询内容到文本(不是导出查询结果)
    八步搞定亚马逊中国区HTTPS负载均衡器设置
    这辈子只能碰到一次! 记一次SSL无故被撤消!
    亚马逊S3数据迁移到中国区
    python2 微信三方登录 中文乱码
    GitLab Wiki 内容恢复版本管理
    Django rest_framework 实用技巧
    Django rest_framework 加入时间间隔
  • 原文地址:https://www.cnblogs.com/tangjunjun/p/12839457.html
Copyright © 2011-2022 走看看