zoukankan      html  css  js  c++  java
  • 【tensorflow】搭建_手写数字识别_神经网络模型:Sequential() / 神经网络类class 两种方法

    MNIST 数据集一共有 7 万张图片,都是 28x28 像素点的 0~9 手写数字,其中 6 万用于训练,1 万张用于测试。

    f.keras + Sequential() 详解

    代码:

    import tensorflow as tf
    
    # 读入训练所需的输入特征和标签
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    
    # 输入特征归一化,减小计算量,方便神经网络吸收
    x_train, x_test = x_train/255.0, x_test/255.0
    
    # 搭建网络
    model = tf.keras.models.Sequential([
        # 将输入特征(28x28)拉直为一维数组(1x748)
        tf.keras.layers.Flatten(),
        # 定义第一层网络,有128个神经元
        tf.keras.layers.Dense(128, activation="relu"),
        # 定义第二层网络,有10个神经元
        tf.keras.layers.Dense(10, activation="softmax")
    ])
    
    # 配置训练方法
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=[tf.keras.metrics.sparse_categorical_accuracy])
    
    # 执行训练过程
    model.fit(x_train, y_train,
              batch_size=32, epochs=5,
              validation_data=(x_test, y_test),
              validation_freq=1)
    
    # 打印出网络结构和参数统计
    model.summary()

    tf.keras + 神经网络类class 详解

    代码:

    import tensorflow as tf
    from tensorflow.keras.layers import Dense, Flatten
    from tensorflow.keras import Model
    
    # 读取训练用的输入特征和标签
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    
    # 输入特征归一化,减小计算量,方便神经网络吸收
    x_train, x_test = x_train/255.0, x_test/255.0
    
    # 定义神经网络类
    class MnistModel(Model):
        def __init__(self):
            super(MnistModel, self).__init__()
            # 定义拉直层
            self.flatten = Flatten()
            # 定义第一层神经网络
            self.d1 = Dense(128, activation="relu")
            # 定义第二层神经网络
            self.d2 = Dense(10, activation="softmax")
    
        def call(self, x):
            # 将输入特征拉直成一维数组
            x = self.flatten(x)
            # 调用剩下两层神经网络,实现前向传播
            x = self.d1(x)
            y = self.d2(x)
            return y
    
    # 声明神经网络对象
    model = MnistModel()
    
    # 配置训练方法
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=[tf.keras.metrics.sparse_categorical_accuracy])
    
    # 执行训练过程
    model.fit(x_train, y_train,
              batch_size=32, epochs=5,
              validation_data=(x_test, y_test),
              validation_freq=1)
    
    # 打印网络结构和参数统计
    model.summary()
  • 相关阅读:
    MSSQL·阻止保存要求重新创建表的更改配置
    MSSQL·查询某数据库中所有表的记录数并排序
    异常处理·psftp·local unable to open
    MSSQL·Execution Timeout Expired. The timeout period elapsed prior to completion of the oper..
    MSSQL·ORDER BY 1 DESC是什么写法?
    MSSQL·大数据量历史数据清理的思路
    ubuntu清理wine卸载后的残余项目
    Learning the Vi Editor, 6th Edition O'Reilly Media
    做一粒不浮躁的好“种子”
    Qt Designer使用简易教程
  • 原文地址:https://www.cnblogs.com/bjxqmy/p/13524576.html
Copyright © 2011-2022 走看看