zoukankan      html  css  js  c++  java
  • 【tensorflow】tf.keras + 神经网络类class 6 步搭建神经网络

    tf.keras + Sequential() 可以搭建出上层输入就是下层输出的顺序网络结构,但是无法写出一些带有跳连的非顺序网络结构。

    这时候可以选择用类 class 搭建神经网络结构,即使用 class 类封装一个网络结构:

     

    ...

    class MyModel(Model):
      def __init__(self):
        super(MyModel, self).__init__()
        定义网络结构块
      def call(self, x):
        调用网络结构块,实现前向传播
        return y


    model = MyModel()

    ...

     

    可以认为 __init__() 函数准备出了搭建网络所需的各种积木,call() 函数调用积木完成了神经网络的搭建,即实现前向传播,从输入x到输出y。

     

    六步:

    1. import 相关模块。
    2. 指定要喂入网络的训练集和测试集。
    3. class MyModel(Model)   model = MyModel()
    4. 在 compile() 中配置训练方法。
    5. 在 fit() 中执行训练过程。
    6. 用 summary() 打印出网络的结构和参数统计

     

    代码:

    import tensorflow as tf
    from  tensorflow.keras.layers import Dense
    from tensorflow.keras import Model
    from sklearn import datasets
    import numpy as np
    
    # 读取训练用的输入特征和标签
    x_train = datasets.load_iris().data
    y_train = datasets.load_iris().target
    
    # 数据集乱序
    np.random.seed(116)
    np.random.shuffle(x_train)
    np.random.seed(116)
    np.random.shuffle(y_train)
    tf.random.set_seed(116)
    
    # 定义神经网络类
    class IrisModel(Model):
        def __init__(self):
            super(IrisModel, self).__init__()
            # 定义第一层网络结构:含有三个神经元的全连接层
            self.d1 = Dense(3, activation="softmax", kernel_regularizer=tf.keras.regularizers.l2())
    
        def call(self, x):
            # 调用网络结构,实现前向传播
            y = self.d1(x)
            return y
    
    # 声明神经网络模型对象
    model = IrisModel()
    
    # 配置训练方法
    model.compile(optimizer=tf.optimizers.SGD(lr=0.1),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=[tf.keras.metrics.sparse_categorical_accuracy])
    
    # 执行训练过程
    model.fit(x_train, y_train,
              batch_size=32, epochs=500,
              validation_split=0.2,
              validation_freq=20)
    
    # 打印网络的结构和参数
    model.summary()

     

  • 相关阅读:
    java swing学习
    JCheckBox相关知识点
    【python 第五日】 函数闭包与装饰器
    【python第四日】 文件处理 生成器 迭代器
    【Python3 第三日】%和format格式化输出 函数
    【python第二日】运算符 数据类型(数字 字符串 列表 元组 字典 集合) 重新定义比较大小
    怎么设置博客园样式
    【python】第一日 python2和python3区别 命名方式 三种结构
    mybatis-generator.xml
    SpringBoot集成mybatis和mybatis generator
  • 原文地址:https://www.cnblogs.com/bjxqmy/p/13524023.html
Copyright © 2011-2022 走看看