参考文献:Géron, Aurélien. Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow: Concepts, Tools, and Techniques to Build Intelligent Systems. Reilly Media, 2019.
除了使用函数API外,还可以通过子类(subclass)自定义神经网络模型。
假设要搭建如图所示的神经网格,使用函数API:
input_A = keras.layers.Input(shape=[5], name="wide_input")
input_B = keras.layers.Input(shape=[6], name="deep_input")
hidden1 = keras.layers.Dense(30, activation="relu")(input_B)
hidden2 = keras.layers.Dense(30, activation="relu")(hidden1)
concat = keras.layers.concatenate([input_A, hidden2])
output = keras.layers.Dense(1, name="main_output")(concat)
aux_output = keras.layers.Dense(1, name="aux_output")(hidden2)
model = keras.models.Model(inputs=[input_A, input_B],
outputs=[output, aux_output])
换成子类API,
class WideAndDeepModel(keras.models.Model):
def __init__(self, units=30, activation="relu", **kwargs):
super().__init__(**kwargs)
self.hidden1 = keras.layers.Dense(units, activation=activation)
self.hidden2 = keras.layers.Dense(units, activation=activation)
self.main_output = keras.layers.Dense(1)
self.aux_output = keras.layers.Dense(1)
def call(self, inputs):
input_A, input_B = inputs
hidden1 = self.hidden1(input_B)
hidden2 = self.hidden2(hidden1)
concat = keras.layers.concatenate([input_A, hidden2])
main_output = self.main_output(concat)
aux_output = self.aux_output(hidden2)
return main_output, aux_output
初始化模型并编译
model = WideAndDeepModel(30, activation="relu")
model.compile(loss="mse", loss_weights=[0.9, 0.1], optimizer=keras.optimizers.SGD(lr=1e-3))
和函数式API不同,使用子类搭建的神经网络,如果运行model.summary,系统会报错
ValueError: This model has not yet been built. Build the model first by calling `build()` or calling `fit()` with some data, or specify an `input_shape` argument in the first layer(s) for automatic build.
这是因为通过子类搭建的网络中不存在graph,即没有网络层之间的连接信息,因此无法使用model.summary()
。如果想要使用model.summary()
,有两种方法:
第一种方法比较别扭,就是先读入数据训练一次,
history = model.fit((X_train_A, X_train_B), (y_train, y_train), epochs=2,
validation_data=((X_valid_A, X_valid_B), (y_valid, y_valid)))
再运行model.summary
就可以输出模型信息
Model: "wide_and_deep_model_4"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_28 (Dense) multiple 210
_________________________________________________________________
dense_29 (Dense) multiple 930
_________________________________________________________________
dense_30 (Dense) multiple 36
_________________________________________________________________
dense_31 (Dense) multiple 31
=================================================================
Total params: 1,207
Trainable params: 1,207
Non-trainable params: 0
_________________________________________________________________
不同于通过子类API搭建的模型,使用model.summary()无法输出神经网络的详细信息,这是子类API的缺点。
第二种方法其实报错信息里提示,就是需要先运行一次模型build
,输入神经网络的input shape。需要注意的是,这是一个Multi-Inputs神经网格,因此input shape是一个列表
model.build([(None, 5),(None, 6)])
之后再运行一次model.summary()
就不会报错。