zoukankan      html  css  js  c++  java
  • 使用tf.keras API 构建神经网络(基础)

    tf2.0推荐的模型搭建方法是:

    1. 继承tf.keras.Model类,进行扩展以定义自己的新模型。
    2. 手工编写模型训练、评估模型的流程。

        (优点:灵活度高;与其他深度学习框架共通)

    以CNN处理单通道图片作为示例:

    class CNN(tf.keras.Model):
        def __init__(self): #定义类的构造方法(这里是初始化预定义好的网络结构)
            super().__init__() #这个类是继承tf.keras.Model类,因此执行父类的初始化
            self.conv1 = tf.keras.layers.Conv2D(
                filters=32,             # 卷积层神经元(卷积核)数目
                kernel_size=[5, 5],     # 感受野大小
                padding='same',         # padding策略(vaild 或 same)
                activation=tf.nn.relu   # 激活函数
            )
            self.pool1 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
            self.conv2 = tf.keras.layers.Conv2D(
                filters=64,
                kernel_size=[5, 5],
                padding='same',
                activation=tf.nn.relu
            )
            self.pool2 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
            self.flatten = tf.keras.layers.Reshape(target_shape=(7 * 7 * 64,))
            self.dense1 = tf.keras.layers.Dense(units=1024, activation=tf.nn.relu)
            self.dense2 = tf.keras.layers.Dense(units=10)
     
        def call(self, inputs):
            x = self.conv1(inputs)                  # [batch_size, 28, 28, 32]
            x = self.pool1(x)                       # [batch_size, 14, 14, 32]
            x = self.conv2(x)                       # [batch_size, 14, 14, 64]
            x = self.pool2(x)                       # [batch_size, 7, 7, 64]
            x = self.flatten(x)                     # [batch_size, 7 * 7 * 64]
            x = self.dense1(x)                      # [batch_size, 1024]
            x = self.dense2(x)                      # [batch_size, 10]
            output = tf.nn.softmax(x)
            return output

    下面解释一下这种网络构建方法:

    1. 我们定义了一个类CNN来继承tf.keras.Model类,目的是为了相较于原类能够有更多自定义的方法,更灵活
    2. 自定义的类中,首先在__init__中定义类的构造方法。构造方法中我们定义了模型中的各个层、以及对各个层的参数赋值(将tf.keras.layers中包装的‘层’实例化)。(建议定义的顺序按照设计的CNN网络架构的顺序排列,便于理解)
    3. 定义一个call方法,一个类只要实现了call方法,这个类的实例就可以用函数一样的形式进行调用,如CNN_obj = CNN(); CNN_obj()这种形式,并可以向其传递参数。
    4. 在我们自定义的类中,call方法要接受训练数据的特征,特征在定义的层中顺序传递,最后输出预测值,用于后续计算。
  • 相关阅读:
    python调用WebService遇到的问题'Document' object has no attribute 'set'
    jquery AJAX 拦截器 success error
    js 钩子(hook)
    js 继承
    js Object的复制
    js关于 indexOf
    js重排序,笔记
    js类型检测,笔记
    jquery源码的阅读理解
    Windows IPC 连接详解(转)
  • 原文地址:https://www.cnblogs.com/mx0813/p/12622765.html
Copyright © 2011-2022 走看看