zoukankan      html  css  js  c++  java
  • 自定义层or网络

    Outline

    • keras.Sequential

    • keras.layers.Layer

    • keras.Model

    keras.Sequential

    • model.trainable_variables # 管理参数

    • model.call()

    network = Sequential([
        layers.Dense(256, acitvaiton='relu'),
        layers.Dense(128, acitvaiton='relu'),
        layers.Dense(64, acitvaiton='relu'),
        layers.Dense(32, acitvaiton='relu'),
        layers.Dense(10)
    ])
    network.build(input_shape=(None, 28 * 28))
    network.summary()
    

    Layer/Model

    • Inherit from keras.layers.Layer/keras.Model

    • __init__

    • call

    • Model:compile/fit/evaluate

    MyDense

    class MyDense(layers.Layer):
        def __init__(self, inp_dim, outp_dim):
            super(MyDense, self).__init__()
    
            self.kernel = self.add_variable('w', [imp_dim, outp_dim])
            self.bias = self.add_variable('b', [outp_dim])
    
        def call(self, inputs, training=None):
            out = input @ self.kernel + self.bias
    
            return out
    

    MyModel

    class MyModel(keras.Model):
        def __init__(self):
            super(MyModel, self).__init__()
            self.fc1 = MyDense(28 * 28, 256)
            self.fc2 = MyDense(256, 128)
            self.fc3 = MyDense(128, 64)
            self.fc4 = MyDense(64, 32)
            self.fc5 = MyDense(32, 10)
    
        def call(self, iputs, training=None):
            x = self.fc1(inputs)
            x = tf.nn.relu(x)
            x = self.fc2(x)
            x = tf.nn.relu(x)
            x = self.fc3(x)
            x = tf.nn.relu(x)
            x = self.fc4(x)
            x = tf.nn.relu(x)
            x = self.fc5(x)
    
            return x
    
  • 相关阅读:
    第五次站立会议
    第四次站立会议
    第三次晚间站立总结会议
    易校小程序典型用户需求分析
    第三次站立会议
    第二次晚间站立总结会议
    第二次站立会议
    第一次晚间站立总结会议
    MyBatis注解
    延迟加载与缓存
  • 原文地址:https://www.cnblogs.com/nickchen121/p/10922806.html
Copyright © 2011-2022 走看看