zoukankan      html  css  js  c++  java
  • 【tensorflow】tf.keras + 神经网络类class 6 步搭建神经网络

    tf.keras + Sequential() 可以搭建出上层输入就是下层输出的顺序网络结构,但是无法写出一些带有跳连的非顺序网络结构。

    这时候可以选择用类 class 搭建神经网络结构,即使用 class 类封装一个网络结构:

     

    ...

    class MyModel(Model):
      def __init__(self):
        super(MyModel, self).__init__()
        定义网络结构块
      def call(self, x):
        调用网络结构块,实现前向传播
        return y


    model = MyModel()

    ...

     

    可以认为 __init__() 函数准备出了搭建网络所需的各种积木,call() 函数调用积木完成了神经网络的搭建,即实现前向传播,从输入x到输出y。

     

    六步:

    1. import 相关模块。
    2. 指定要喂入网络的训练集和测试集。
    3. class MyModel(Model)   model = MyModel()
    4. 在 compile() 中配置训练方法。
    5. 在 fit() 中执行训练过程。
    6. 用 summary() 打印出网络的结构和参数统计

     

    代码:

    import tensorflow as tf
    from  tensorflow.keras.layers import Dense
    from tensorflow.keras import Model
    from sklearn import datasets
    import numpy as np
    
    # 读取训练用的输入特征和标签
    x_train = datasets.load_iris().data
    y_train = datasets.load_iris().target
    
    # 数据集乱序
    np.random.seed(116)
    np.random.shuffle(x_train)
    np.random.seed(116)
    np.random.shuffle(y_train)
    tf.random.set_seed(116)
    
    # 定义神经网络类
    class IrisModel(Model):
        def __init__(self):
            super(IrisModel, self).__init__()
            # 定义第一层网络结构:含有三个神经元的全连接层
            self.d1 = Dense(3, activation="softmax", kernel_regularizer=tf.keras.regularizers.l2())
    
        def call(self, x):
            # 调用网络结构,实现前向传播
            y = self.d1(x)
            return y
    
    # 声明神经网络模型对象
    model = IrisModel()
    
    # 配置训练方法
    model.compile(optimizer=tf.optimizers.SGD(lr=0.1),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=[tf.keras.metrics.sparse_categorical_accuracy])
    
    # 执行训练过程
    model.fit(x_train, y_train,
              batch_size=32, epochs=500,
              validation_split=0.2,
              validation_freq=20)
    
    # 打印网络的结构和参数
    model.summary()

     

  • 相关阅读:
    苹果一体机发射Wi-Fi
    iphone 屏蔽系统自动更新,消除设置上的小红点
    data parameter is nil 异常处理
    copy与mutableCopy的区别总结
    java axis2 webservice
    mysql 远程 ip访问
    mysql 存储过程小问题
    mysql游标错误
    is not writable or has an invalid setter method错误的解决
    Struts2中关于"There is no Action mapped for namespace / and action name"的总结
  • 原文地址:https://www.cnblogs.com/bjxqmy/p/13524023.html
Copyright © 2011-2022 走看看