zoukankan      html  css  js  c++  java
  • 【tensorflow】使用_Fashion数据集_搭建神经网络模型:Sequential() / 神经网络类class 两种方法

    FASHION 数据集一共有 7 万张图片,每张图片都是 28x28 像素点的灰度值数据,其中 6 万张用于训练,1 万张用于测试。

    一共有 10 个分类:

    0 T恤

    1 裤子

    2 帽头衫

    3 连衣裙

    4 外套

    5 凉鞋

    6 衬衫

    7 运动鞋

    8 包

    9 靴子

    f.keras + Sequential() 详解

    代码:

    import tensorflow as tf
    
    # 读取训练用的输入特征和标签
    fashion = tf.keras.datasets.fashion_mnist
    (x_train, y_train), (x_test, y_test) = fashion.load_data()
    
    # 输入特征归一化,减小计算量,方便神经网络吸收
    x_train, x_test = x_train/255.0, x_test/255.0
    
    # 声明网络结构
    model = tf.keras.models.Sequential([
        # 拉直层
        tf.keras.layers.Flatten(),
        # 两层全连接层
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(10, activation="softmax")
    ])
    
    # 配置训练方法(优化器,损失函数,评测方法)
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=[tf.keras.metrics.sparse_categorical_accuracy])
    
    # 执行训练过程
    model.fit(x_train, y_train,
              batch_size=32, epochs=5,
              validation_data=(x_test, y_test),
              validation_freq=1)
    
    # 打印网络结构和参数统计
    model.summary()

    tf.keras + 神经网络类class 详解

    代码:

    import tensorflow as tf
    from tensorflow.keras.layers import Dense, Flatten
    from tensorflow.keras import Model
    
    # 读取训练用的输入特征和标签
    fashion = tf.keras.datasets.fashion_mnist
    (x_train, y_train), (x_test, y_test) = fashion.load_data()
    
    # 输入特征归一化,减小计算量,方便神经网络吸收
    x_train, x_test = x_train/255.0, x_test/255.0
    
    # 定义神经网络类
    class FashionModel(Model):
        # 定义网络结构
        def __init__(self):
            super(FashionModel, self).__init__()
            self.flatten = Flatten()
            self.d1 = Dense(128, activation="relu")
            self.d2 = Dense(10, activation="softmax")
    
        # 调用网络结构,实现前向传播
        def call(self, inputs, training=None, mask=None):
            x = self.flatten(inputs)
            x = self.d1(x)
            y = self.d2(x)
            return y
    
    # 声明神经网络对象
    model = FashionModel()
    
    # 配置训练方法(优化器,损失函数,评测指标)
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=[tf.keras.metrics.sparse_categorical_accuracy])
    
    # 执行训练过程
    model.fit(x_train, y_train,
              batch_size=32, epochs=5,
              validation_data=(x_test, y_test),
              validation_freq=1)
    
    # 打印网络结构和参数
    model.summary()
  • 相关阅读:
    Linux-netstat
    API接口防止参数篡改和重放攻击
    Java中遍历Map的几种方式
    Java泛型中的标记符含义
    Iterator 和 for...of 循环
    Promise 对象
    Reflect
    正则要求密码长度最少12位,包含至少1个特殊字符,2个数字,2个大写字母和一些小写字母。
    一个JS正则表达式,一个正实数,整数部分最多11位 小数部分最多 8位
    java阿里云短信发送配置
  • 原文地址:https://www.cnblogs.com/bjxqmy/p/13527573.html
Copyright © 2011-2022 走看看