zoukankan      html  css  js  c++  java
  • Keras函数式API介绍

    参考文献:Géron, Aurélien. Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow: Concepts, Tools, and Techniques to Build Intelligent Systems. O'Reilly Media, 2019.

    Keras的Sequential顺序模型可以快速搭建简易的神经网络,同时Keras也提供函数式API(Functional API)用于定制各种不同类型的网格结构。

    Concatenate

    在搭建Wide & Deep神经网格的时候,需要进行层融合(concatenation)。图中的融合层将Input层和hidden layer最后一层相加在一起。

    input_ = keras.layers.Input(shape=X_train.shape[1:])
    hidden1 = keras.layers.Dense(30, activation="relu")(input_)
    hidden2 = keras.layers.Dense(30, activation="relu")(hidden1)
    concat = keras.layers.concatenate([input_, hidden2])
    output = keras.layers.Dense(1)(concat)
    model = keras.models.Model(inputs=[input_], outputs=[output])
    

    Multi-inputs

    可以将输入特征先分成多组(可以有重叠部分),让它们分别通过神经网络中的不同路径。

    input_A = keras.layers.Input(shape=[5], name="wide_input")
    input_B = keras.layers.Input(shape=[6], name="deep_input")
    hidden1 = keras.layers.Dense(30, activation="relu")(input_B)
    hidden2 = keras.layers.Dense(30, activation="relu")(hidden1)
    concat = keras.layers.concatenate([input_A, hidden2])
    output = keras.layers.Dense(1, name="output")(concat)
    model = keras.models.Model(inputs=[input_A, input_B], outputs=[output])
    

    Multi-outputs

    input_A = keras.layers.Input(shape=[5], name="wide_input")
    input_B = keras.layers.Input(shape=[6], name="deep_input")
    hidden1 = keras.layers.Dense(30, activation="relu")(input_B)
    hidden2 = keras.layers.Dense(30, activation="relu")(hidden1)
    concat = keras.layers.concatenate([input_A, hidden2])
    output = keras.layers.Dense(1, name="main_output")(concat)
    aux_output = keras.layers.Dense(1, name="aux_output")(hidden2)
    model = keras.models.Model(inputs=[input_A, input_B],
                               outputs=[output, aux_output])
    

    每个output可以单独设置损失函数

    model.compile(loss=[“mse”,”mse”], loss_weights=[0.9, 0.1], optimizer=“sgd”)
    

    如果不设置的话,Keras默认使用相同的损失函数。训练中,Keras会单独计算两个损失函数,相加一起得到作为最后的损失值。

  • 相关阅读:
    用扑克牌保存文本信息
    计算机网络7--报文交换
    算法——字符串匹配之BM算法
    Head First Python 学习笔记-Chapter3:文件读取和异常处理
    页面登陆框老是乱乱的?banner跨页图片缩小之后总是在側面不能显示主要部分?哈哈~我来帮你忙~~
    happens-before通俗理解
    Eclipse中Git插件还原文件
    集成 Tomcat 插件到 Eclipse 的过程
    深入理解ClassLoader(五)—类的卸载
    使用eclipse远程调试Tomcat的方法
  • 原文地址:https://www.cnblogs.com/yaos/p/14014172.html
Copyright © 2011-2022 走看看