zoukankan      html  css  js  c++  java
  • TensorFlow(二十八):Keras 自定义层——继承 Layer 与 Model

    一、讲解

     

     

     

     

     

    二、代码

    import os

    # Set the native log level BEFORE importing TensorFlow: the C++ runtime may
    # read this variable when it emits its first log line, which can happen
    # during the import itself, so setting it afterwards may not suppress those
    # messages. '2' filters out INFO and WARNING messages from the C++ backend.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    import tensorflow as tf
    # NOTE(review): tensorflow.python.keras is a private/legacy module path;
    # the public API is tensorflow.keras — consider switching.
    from tensorflow.python.keras import datasets, layers, optimizers, Sequential, metrics
    from tensorflow.python import keras
    
    def preprocess(x, y):
        """
        Turn a single (image, label) example into model-ready tensors.

        Works on one example at a time, not on a batch.

        :param x: one image; it is expected to contain 28*28 = 784 pixel values
        :param y: the integer class label of that image
        :return: tuple (image, label), where image is a flat float32 vector of
                 length 784 holding the pixels divided by 255, and label is a
                 one-hot vector of depth 10
        """
        # Scale to float first, then flatten the 2-D image into one vector.
        image = tf.reshape(tf.cast(x, dtype=tf.float32) / 255., [28 * 28])
        # Integer label -> one-hot encoding over the 10 classes.
        label = tf.one_hot(tf.cast(y, dtype=tf.int32), depth=10)
        return image, label
    
    batchsz = 128

    # Keras returns MNIST already split; the held-out split is used for validation.
    (x, y), (x_val, y_val) = datasets.mnist.load_data()
    print("datasets: ", x.shape, y.shape, x.min(), x.max())

    # Training pipeline: per-example preprocessing, shuffle with a buffer of
    # 60000 (the MNIST training-set size), then batching.
    db = (
        tf.data.Dataset.from_tensor_slices((x, y))
        .map(preprocess)
        .shuffle(60000)
        .batch(batchsz)
    )
    # Validation pipeline: same preprocessing, kept in order (no shuffle).
    ds_val = (
        tf.data.Dataset.from_tensor_slices((x_val, y_val))
        .map(preprocess)
        .batch(batchsz)
    )

    # Pull one training batch to sanity-check the tensor shapes
    # (the printed label means "obtained from the iterator").
    iteration = iter(db)
    sample = next(iteration)
    print("迭代器获得为:", sample[0].shape, sample[1].shape)
    
    
    class MyDense(layers.Layer):
        """Fully connected layer with no activation: ``out = input @ kernel + bias``.

        The input width is passed explicitly, so the weights are created
        eagerly in ``__init__`` rather than in ``build()``.
        """

        def __init__(self, inp_dim, outp_dim, **kwargs):
            """
            :param inp_dim: size of the last axis of the input
            :param outp_dim: number of output units
            :param kwargs: forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``)
            """
            super(MyDense, self).__init__(**kwargs)

            # add_weight is the supported API; add_variable was only a
            # deprecated alias for it, and newer Keras versions give
            # add_variable a different signature.
            self.kernel = self.add_weight(name='w', shape=[inp_dim, outp_dim])
            self.bias = self.add_weight(name='b', shape=[outp_dim])

        def call(self, input, training=None):
            """Apply the affine map; ``training`` is accepted but unused."""
            # `input` shadows the builtin but is kept for interface compatibility.
            return input @ self.kernel + self.bias
    
    class MyModel(keras.Model):
        """Five-layer fully connected network built from MyDense layers.

        Widths: 784 -> 256 -> 128 -> 64 -> 32 -> 10. ReLU follows every layer
        except the last; no activation is applied after fc5, so the outputs are
        unnormalized scores (logits).
        """

        def __init__(self):
            super(MyModel, self).__init__()

            # Each layer is stored as its own attribute so Keras tracks its
            # weights; the fc1..fc5 names are preserved.
            self.fc1 = MyDense(28*28, 256)
            self.fc2 = MyDense(256, 128)
            self.fc3 = MyDense(128, 64)
            self.fc4 = MyDense(64, 32)
            self.fc5 = MyDense(32, 10)

        def call(self, inputs, training=None):
            """Forward pass; ``training`` is accepted but unused."""
            h = inputs
            for hidden in (self.fc1, self.fc2, self.fc3, self.fc4):
                # Calling a layer instance goes through __call__, which dispatches to call().
                h = tf.nn.relu(hidden(h))
            return self.fc5(h)
    
    network = MyModel()

    # `learning_rate` is the supported keyword; the legacy `lr` alias is
    # deprecated and rejected by newer Keras optimizers.
    network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                    # MyModel applies no activation after its last layer, so the
                    # loss must apply softmax itself (from_logits=True).
                    loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                    metrics=['accuracy']
    )

    # Train for 5 epochs; validation runs only every 2nd epoch (validation_freq=2).
    network.fit(db, epochs=5, validation_data=ds_val,
                validation_freq=2)
    
    
    # Final loss/accuracy on the validation set.
    network.evaluate(ds_val)

    # Spot-check predictions on the first validation batch (ds_val is not shuffled).
    sample = next(iter(ds_val))
    x = sample[0]
    y = sample[1] # one-hot
    pred = network.predict(x) # [b, 10] logits: the model applies no softmax
    # convert back to number
    y = tf.argmax(y, axis=1)  # [b] class indices
    pred = tf.argmax(pred, axis=1)  # [b]; argmax of logits equals argmax of softmax

    print(pred)
    print(y)
  • 相关阅读:
    方法引用(method reference)
    函数式接口
    Lambda 表达式
    LinkedList 源码分析
    ArrayList 源码分析
    Junit 学习笔记
    Idea 使用 Junit4 进行单元测试
    Java 定时器
    【干货】Mysql的"事件探查器"-之Mysql-Proxy代理实战一(安装部署与实战sql拦截与性能监控)
    python-flask框架web服务接口开发实例
  • 原文地址:https://www.cnblogs.com/zhangxianrong/p/14691597.html
Copyright © 2011-2022 走看看