zoukankan      html  css  js  c++  java
  • Keras函数式API介绍

    参考文献:Géron, Aurélien. Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow: Concepts, Tools, and Techniques to Build Intelligent Systems. O'Reilly Media, 2019.

    Keras的Sequential顺序模型可以快速搭建简易的神经网络,同时Keras也提供函数式API(Functional API)用于定制各种不同类型的网格结构。

    Concatenate

    在搭建Wide & Deep神经网格的时候,需要进行层融合(concatenation)。图中的融合层将Input层和hidden layer最后一层相加在一起。

    input_ = keras.layers.Input(shape=X_train.shape[1:])
    hidden1 = keras.layers.Dense(30, activation="relu")(input_)
    hidden2 = keras.layers.Dense(30, activation="relu")(hidden1)
    concat = keras.layers.concatenate([input_, hidden2])
    output = keras.layers.Dense(1)(concat)
    model = keras.models.Model(inputs=[input_], outputs=[output])
    

    Multi-inputs

    可以将输入特征先分成多组(可以有重叠部分),让它们分别通过神经网络中的不同路径。

    input_A = keras.layers.Input(shape=[5], name="wide_input")
    input_B = keras.layers.Input(shape=[6], name="deep_input")
    hidden1 = keras.layers.Dense(30, activation="relu")(input_B)
    hidden2 = keras.layers.Dense(30, activation="relu")(hidden1)
    concat = keras.layers.concatenate([input_A, hidden2])
    output = keras.layers.Dense(1, name="output")(concat)
    model = keras.models.Model(inputs=[input_A, input_B], outputs=[output])
    

    Multi-outputs

    input_A = keras.layers.Input(shape=[5], name="wide_input")
    input_B = keras.layers.Input(shape=[6], name="deep_input")
    hidden1 = keras.layers.Dense(30, activation="relu")(input_B)
    hidden2 = keras.layers.Dense(30, activation="relu")(hidden1)
    concat = keras.layers.concatenate([input_A, hidden2])
    output = keras.layers.Dense(1, name="main_output")(concat)
    aux_output = keras.layers.Dense(1, name="aux_output")(hidden2)
    model = keras.models.Model(inputs=[input_A, input_B],
                               outputs=[output, aux_output])
    

    每个output可以单独设置损失函数

    model.compile(loss=[“mse”,”mse”], loss_weights=[0.9, 0.1], optimizer=“sgd”)
    

    如果不设置的话,Keras默认使用相同的损失函数。训练中,Keras会单独计算两个损失函数,相加一起得到作为最后的损失值。

  • 相关阅读:
    MCPD 70536题目 自定义打印参数
    《ERP从内部集成起步》读书笔记——第5章 MRP系统的时间概念 5.1 时间三要素 5.1.2 时段
    Jquey拖拽控件Draggable用法
    MCPD 70536题目 反射
    MCPD 70536题目 非托管资源 释放
    VS2008创建Silverlight项目时出错解决方法
    程序猿去旅行
    EntityFramework5.0 数据迁移笔记解决模型变化重建数据库的问题
    完美生活
    一直很安静
  • 原文地址:https://www.cnblogs.com/yaos/p/14014172.html
Copyright © 2011-2022 走看看