zoukankan      html  css  js  c++  java
  • 自定义网络搭建

    使用到的API有:keras.Sequential、Layers/Model

    1.keras.Sequential

    以前的代码已经很多次用到了这个接口,这里直接给出代码:

    model = Sequential([
        layers.Dense(256,activation=tf.nn.relu), # [b,784] ==>[b,256]
        layers.Dense(128,activation=tf.nn.relu),
        layers.Dense(64,activation=tf.nn.relu),
        layers.Dense(32,activation=tf.nn.relu),
        layers.Dense(10)
    ])
    
    model.build(input_shape=[None,28*28])
    model.summary()

    Sequential还可以通过一些API去管理参数,如:model.trainable_variables、model.call(),前者是用来获取网络中所有的可训练参数,后者则是相当于逐层调model方法

    2.Layer/Model

    Layer的全路径为keras.layers.Layer,Model的全路径为keras.Model(包含compile,fit,evaluate功能)

    class MyDense(keras.layers.Layer):
        def __init__(self,inp_dim,outp_dim):
            super(MyDense, self).__init__()
    
            self.kernel = self.add_variable('w',[inp_dim,outp_dim])
            self.bias = self.add_variable('b',[outp_dim])
    
        def call(self,inputs,training=None):
            out = inputs @ self.kernel + self.bias
    
            return out
        
    
    class MyModel(keras.Model):
        def __init__(self):
            super(MyModel, self).__init__()
            
            self.fc1 = MyDense(28*28,256)
            self.fc2 = MyDense(256, 128)
            self.fc3 = MyDense(128, 64)
            self.fc4 = MyDense(64, 32)
            self.fc5 = MyDense(32, 10)
            
        def call(self,inputs,training=None):
            x = self.fc1(inputs)
            x = tf.nn.relu(x)
            x = self.fc2(x)
            x = tf.nn.relu(x)
            x = self.fc3(x)
            x = tf.nn.relu(x)
            x = self.fc4(x)
            x = tf.nn.relu(x)
            x = self.fc5(x)
            
            return x
  • 相关阅读:
    水壶-[Kruskal重构树] [解题报告]
    线性求逆元推导
    边界线与两端对齐
    左边竖条的实现方法
    $.ajax()知识
    area热点区域
    AJAX与XMLHttpRequest
    js运行机制
    优先级
    各种图形
  • 原文地址:https://www.cnblogs.com/zdm-code/p/12245906.html
Copyright © 2011-2022 走看看