zoukankan      html  css  js  c++  java
  • Keras通过子类(subclass)自定义神经网络模型

    参考文献:Géron, Aurélien. Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow: Concepts, Tools, and Techniques to Build Intelligent Systems. Reilly Media, 2019.

    除了使用函数API外,还可以通过子类(subclass)自定义神经网络模型。

    假设要搭建如图所示的神经网格,使用函数API:

    input_A = keras.layers.Input(shape=[5], name="wide_input")
    input_B = keras.layers.Input(shape=[6], name="deep_input")
    hidden1 = keras.layers.Dense(30, activation="relu")(input_B)
    hidden2 = keras.layers.Dense(30, activation="relu")(hidden1)
    concat = keras.layers.concatenate([input_A, hidden2])
    output = keras.layers.Dense(1, name="main_output")(concat)
    aux_output = keras.layers.Dense(1, name="aux_output")(hidden2)
    model = keras.models.Model(inputs=[input_A, input_B],
                               outputs=[output, aux_output])

    换成子类API,

    class WideAndDeepModel(keras.models.Model):
        def __init__(self, units=30, activation="relu", **kwargs):
            super().__init__(**kwargs)
            self.hidden1 = keras.layers.Dense(units, activation=activation)
            self.hidden2 = keras.layers.Dense(units, activation=activation)
            self.main_output = keras.layers.Dense(1)
            self.aux_output = keras.layers.Dense(1)
            
        def call(self, inputs):
            input_A, input_B = inputs
            hidden1 = self.hidden1(input_B)
            hidden2 = self.hidden2(hidden1)
            concat = keras.layers.concatenate([input_A, hidden2])
            main_output = self.main_output(concat)
            aux_output = self.aux_output(hidden2)
            return main_output, aux_output

    初始化模型并编译

    model = WideAndDeepModel(30, activation="relu")
    model.compile(loss="mse", loss_weights=[0.9, 0.1], optimizer=keras.optimizers.SGD(lr=1e-3))

    和函数式API不同,使用子类搭建的神经网络,如果运行model.summary,系统会报错

    ValueError: This model has not yet been built. Build the model first by calling `build()` or calling `fit()` with some data, or specify an `input_shape` argument in the first layer(s) for automatic build.

    这是因为通过子类搭建的网络中不存在graph,即没有网络层之间的连接信息,因此无法使用model.summary() 。如果想要使用model.summary(),有两种方法:
    第一种方法比较别扭,就是先读入数据训练一次,

    history = model.fit((X_train_A, X_train_B), (y_train, y_train), epochs=2,
                        validation_data=((X_valid_A, X_valid_B), (y_valid, y_valid)))

    再运行model.summary就可以输出模型信息

    Model: "wide_and_deep_model_4"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    dense_28 (Dense)             multiple                  210       
    _________________________________________________________________
    dense_29 (Dense)             multiple                  930       
    _________________________________________________________________
    dense_30 (Dense)             multiple                  36        
    _________________________________________________________________
    dense_31 (Dense)             multiple                  31        
    =================================================================
    Total params: 1,207
    Trainable params: 1,207
    Non-trainable params: 0
    _________________________________________________________________

    不同于通过子类API搭建的模型,使用model.summary()无法输出神经网络的详细信息,这是子类API的缺点。
    第二种方法其实报错信息里提示,就是需要先运行一次模型build,输入神经网络的input shape。需要注意的是,这是一个Multi-Inputs神经网格,因此input shape是一个列表

    model.build([(None, 5),(None, 6)])

    之后再运行一次model.summary()就不会报错。

  • 相关阅读:
    Ralasafe基于策略模型
    如何让Oracle表字段自动增长
    Oracle中Number类型字段使用.netTiers和CodeSmith问题的解决方案
    GridView的DataFormatString参考
    解决.NET连接Oracle数据库的一些问题(转)
    C# WinForm开发系列 DataGridView
    C# 插件式程序开发
    Oracle中“字符串中的字符大小写敏感处理方法”
    做一个项目,平时都用到哪些工具提高效率(中)
    折腾了这么多年的.NET开发,也只学会了这么几招 软件开发不是生活的全部,但是好的生活全靠它了(转)
  • 原文地址:https://www.cnblogs.com/yaos/p/14014173.html
Copyright © 2011-2022 走看看