zoukankan      html  css  js  c++  java
  • 【tensorflow】搭建_手写数字识别_神经网络模型:Sequential() / 神经网络类class 两种方法

    MNIST 数据集一共有 7 万张图片,都是 28x28 像素点的 0~9 手写数字,其中 6 万用于训练,1 万张用于测试。

    f.keras + Sequential() 详解

    代码:

    import tensorflow as tf
    
    # 读入训练所需的输入特征和标签
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    
    # 输入特征归一化,减小计算量,方便神经网络吸收
    x_train, x_test = x_train/255.0, x_test/255.0
    
    # 搭建网络
    model = tf.keras.models.Sequential([
        # 将输入特征(28x28)拉直为一维数组(1x748)
        tf.keras.layers.Flatten(),
        # 定义第一层网络,有128个神经元
        tf.keras.layers.Dense(128, activation="relu"),
        # 定义第二层网络,有10个神经元
        tf.keras.layers.Dense(10, activation="softmax")
    ])
    
    # 配置训练方法
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=[tf.keras.metrics.sparse_categorical_accuracy])
    
    # 执行训练过程
    model.fit(x_train, y_train,
              batch_size=32, epochs=5,
              validation_data=(x_test, y_test),
              validation_freq=1)
    
    # 打印出网络结构和参数统计
    model.summary()

    tf.keras + 神经网络类class 详解

    代码:

    import tensorflow as tf
    from tensorflow.keras.layers import Dense, Flatten
    from tensorflow.keras import Model
    
    # 读取训练用的输入特征和标签
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    
    # 输入特征归一化,减小计算量,方便神经网络吸收
    x_train, x_test = x_train/255.0, x_test/255.0
    
    # 定义神经网络类
    class MnistModel(Model):
        def __init__(self):
            super(MnistModel, self).__init__()
            # 定义拉直层
            self.flatten = Flatten()
            # 定义第一层神经网络
            self.d1 = Dense(128, activation="relu")
            # 定义第二层神经网络
            self.d2 = Dense(10, activation="softmax")
    
        def call(self, x):
            # 将输入特征拉直成一维数组
            x = self.flatten(x)
            # 调用剩下两层神经网络,实现前向传播
            x = self.d1(x)
            y = self.d2(x)
            return y
    
    # 声明神经网络对象
    model = MnistModel()
    
    # 配置训练方法
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=[tf.keras.metrics.sparse_categorical_accuracy])
    
    # 执行训练过程
    model.fit(x_train, y_train,
              batch_size=32, epochs=5,
              validation_data=(x_test, y_test),
              validation_freq=1)
    
    # 打印网络结构和参数统计
    model.summary()
  • 相关阅读:
    有效的字母异位词---简单
    字符串中的第一个唯一字符---简单
    整数反转---简单
    使用httpServlet方法开发
    servlet生命周期
    tomcat_user文件的1配置
    selvlet入门自己部署(sevlet接口实现)
    数据库mysql实战
    tomcat的熟悉目录结构
    虚拟主机
  • 原文地址:https://www.cnblogs.com/bjxqmy/p/13524576.html
Copyright © 2011-2022 走看看