zoukankan      html  css  js  c++  java
  • 【tensorflow】tf.keras + 神经网络类class 6 步搭建神经网络

    tf.keras + Sequential() 可以搭建出上层输入就是下层输出的顺序网络结构,但是无法写出一些带有跳连的非顺序网络结构。

    这时候可以选择用类 class 搭建神经网络结构,即使用 class 类封装一个网络结构:

     

    ...

    class MyModel(Model):
      def __init__(self):
        super(MyModel, self).__init__()
        定义网络结构块
      def call(self, x):
        调用网络结构块,实现前向传播
        return y


    model = MyModel()

    ...

     

    可以认为 __init__() 函数准备出了搭建网络所需的各种积木,call() 函数调用积木完成了神经网络的搭建,即实现前向传播,从输入x到输出y。

     

    六步:

    1. import 相关模块。
    2. 指定要喂入网络的训练集和测试集。
    3. class MyModel(Model)   model = MyModel()
    4. 在 compile() 中配置训练方法。
    5. 在 fit() 中执行训练过程。
    6. 用 summary() 打印出网络的结构和参数统计

     

    代码:

    import tensorflow as tf
    from  tensorflow.keras.layers import Dense
    from tensorflow.keras import Model
    from sklearn import datasets
    import numpy as np
    
    # 读取训练用的输入特征和标签
    x_train = datasets.load_iris().data
    y_train = datasets.load_iris().target
    
    # 数据集乱序
    np.random.seed(116)
    np.random.shuffle(x_train)
    np.random.seed(116)
    np.random.shuffle(y_train)
    tf.random.set_seed(116)
    
    # 定义神经网络类
    class IrisModel(Model):
        def __init__(self):
            super(IrisModel, self).__init__()
            # 定义第一层网络结构:含有三个神经元的全连接层
            self.d1 = Dense(3, activation="softmax", kernel_regularizer=tf.keras.regularizers.l2())
    
        def call(self, x):
            # 调用网络结构,实现前向传播
            y = self.d1(x)
            return y
    
    # 声明神经网络模型对象
    model = IrisModel()
    
    # 配置训练方法
    model.compile(optimizer=tf.optimizers.SGD(lr=0.1),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=[tf.keras.metrics.sparse_categorical_accuracy])
    
    # 执行训练过程
    model.fit(x_train, y_train,
              batch_size=32, epochs=500,
              validation_split=0.2,
              validation_freq=20)
    
    # 打印网络的结构和参数
    model.summary()

     

  • 相关阅读:
    简单理解Socket
    进程间8种通信方式详解
    底部漂浮DIV
    Table样式
    QQ授权登录
    C#_批量插入数据到Sqlserver中的四种方式
    Asp.Net_单点登录
    html之meta详解
    程序员常用工具
    工厂模式理解
  • 原文地址:https://www.cnblogs.com/bjxqmy/p/13524023.html
Copyright © 2011-2022 走看看