zoukankan      html  css  js  c++  java
  • 『开发技巧』Keras自定义对象(层、评价函数与损失)

    1.自定义层

    对于简单、无状态的自定义操作,你也许可以通过 layers.core.Lambda 层来实现。但是对于那些包含了可训练权重的自定义层,你应该自己实现这种层。

    这是一个 Keras2.0 中,Keras 层的骨架(如果你用的是旧的版本,请更新到新版)。你只需要实现三个方法即可:

    • build(input_shape): 这是你定义权重的地方。这个方法必须设 self.built = True,可以通过调用 super([Layer], self).build() 完成。
    • call(x): 这里是编写层的功能逻辑的地方。你只需要关注传入 call 的第一个参数:输入张量,除非你希望你的层支持masking。
    • compute_output_shape(input_shape): 如果你的层更改了输入张量的形状,你应该在这里定义形状变化的逻辑,这让Keras能够自动推断各层的形状。
    from keras import backend as K
    from keras.engine.topology import Layer
    
    class MyLayer(Layer):
    
        def __init__(self, output_dim, **kwargs):
            self.output_dim = output_dim
            super(MyLayer, self).__init__(**kwargs)
    
        def build(self, input_shape):
            # 为该层创建一个可训练的权重
            self.kernel = self.add_weight(name='kernel', 
                                          shape=(input_shape[1], self.output_dim),
                                          initializer='uniform',
                                          trainable=True)
            super(MyLayer, self).build(input_shape)  # 一定要在最后调用它
    
        def call(self, x):
            return K.dot(x, self.kernel)
    
        def compute_output_shape(self, input_shape):
            return (input_shape[0], self.output_dim)
    

    还可以定义具有多个输入张量和多个输出张量的 Keras 层。 为此,你应该假设方法 build(input_shape)call(x) 和 compute_output_shape(input_shape) 的输入输出都是列表。 这里是一个例子,与上面那个相似:

    from keras import backend as K
    from keras.engine.topology import Layer
    
    class MyLayer(Layer):
    
        def __init__(self, output_dim, **kwargs):
            self.output_dim = output_dim
            super(MyLayer, self).__init__(**kwargs)
    
        def build(self, input_shape):
            assert isinstance(input_shape, list)
            # 为该层创建一个可训练的权重
            self.kernel = self.add_weight(name='kernel',
                                          shape=(input_shape[0][1], self.output_dim),
                                          initializer='uniform',
                                          trainable=True)
            super(MyLayer, self).build(input_shape)  # 一定要在最后调用它
    
        def call(self, x):
            assert isinstance(x, list)
            a, b = x
            return [K.dot(a, self.kernel) + b, K.mean(b, axis=-1)]
    
        def compute_output_shape(self, input_shape):
            assert isinstance(input_shape, list)
            shape_a, shape_b = input_shape
            return [(shape_a[0], self.output_dim), shape_b[:-1]]
    

    已有的 Keras 层就是实现任何层的很好例子。不要犹豫阅读源码!

    2.自定义评价函数

    自定义评价函数应该在编译的时候(compile)传递进去。该函数需要以 (y_true, y_pred) 作为输入参数,并返回一个张量作为输出结果。

    import keras.backend as K
    
    def mean_pred(y_true, y_pred):
        return K.mean(y_pred)
    
    model.compile(optimizer='rmsprop',
                  loss='binary_crossentropy',
                  metrics=['accuracy', mean_pred])
    

    3.自定义损失函数

    自定义损失函数也应该在编译的时候(compile)传递进去。该函数需要以 (y_true, y_pred) 作为输入参数,并返回一个张量作为输出结果。

    import keras.backend as K
    
    def my_loss(y_true, y_pred):
        return K.mean(K.squre(y_pred-y_true))#以平方差举例
    
    model.compile(optimizer='rmsprop',
                  loss=my_loss,
                  metrics=['accuracy'])
    

    4.处理已保存模型中的自定义层(或其他自定义对象)

    如果要加载的模型包含自定义层或其他自定义类或函数,则可以通过 custom_objects 参数将它们传递给加载机制:

    from keras.models import load_model
    # 假设你的模型包含一个 AttentionLayer 类的实例
    model = load_model('my_model.h5', custom_objects={'AttentionLayer': AttentionLayer})
    

    或者,你可以使用 自定义对象作用域

    from keras.utils import CustomObjectScope
    
    with CustomObjectScope({'AttentionLayer': AttentionLayer}):
        model = load_model('my_model.h5')
    
  • 相关阅读:
    QT启动画面不显示
    指针运算,终于明白了
    sourceforge 优秀 开源 项目 介绍
    ios验证邮箱格式
    获取手机当前经纬度的方法
    将UIImageView改成圆角和圆形
    Extensible Messaging and Presence Protocol (XMPP) 的实现
    导入libxml2.dylib时出问题
    Extensible Messaging and Presence Protocol (XMPP) 简介
    XMPPFramework 常用api包简介
  • 原文地址:https://www.cnblogs.com/xiaosongshine/p/11188063.html
Copyright © 2011-2022 走看看