zoukankan      html  css  js  c++  java
  • Tensorflow05-简单的全连接神经网络案例

    第一个神经元网络就使用最简单的全连接神经网络。

    使用tensorflow里的 fashion_mnist 服饰数据集 来完成此次的入门案例,建议使用 jupyter 分步执行,每步都理解掌握。

    数据集介绍:大概60000张图片,分成了衣服帽子鞋子等等10个类别。每张图片是由 28*28 个像素组成的,每个像素取值 0 ~ 255。

    import tensorflow as tf
    from tensorflow import keras
    import matplotlib.pyplot as plt
    
    
    # 加载数据集
    fashion_mnist = keras.datasets.fashion_mnist
    # 得到训练/测试 数据,训练/测试 标签
    (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
    
    # 查看数据形状
    train_images.shape, train_labels.shape
    
    plt.imshow(test_images[0])   # 画图用  imshow !
    
    # 创建神经元模型
    model = keras.Sequential()
    # 第一层使用Flatten
    model.add(keras.layers.Flatten(input_shape=(28, 28)))
    model.add(keras.layers.Dense(128, activation=tf.nn.relu))
    model.add(keras.layers.Dense(10, activation=tf.nn.softmax))
    
    # 查看神经网络结构
    model.summary()
    
    #配置训练方法,optimizer(优化器)为 经常使用的Adam,损失函数使用sparse_categorical_crossentropy,注意还有不带sparse的,则表示数据为 独热编码形式的。
    model.compile(optimizer=tf.optimizers.Adam(), loss=tf.losses.sparse_categorical_crossentropy, metrics=['accuracy'])
    
    # 为防止过拟合定义该类
    class myCallback(tf.keras.callbacks.Callback): # 继承自 Callback
        def on_epoch_end(self, epoch, logs={}): # 重写该方法
            if(logs.get('loss') < 0.4): # 如果 loss < 0.4, 认为发生过拟合
                print("
    Loss is low so cancelling training")
                self.model.stop_training = True # 停止训练
    
    callbacks = myCallback()
    
    # 归一化
    train_images = train_images/255
    test_images_scaled = test_images/255
    
    # 训练数据得到 history 对象,最后一个参数表示自动中止训练,类的定义在上方
    history = model.fit(train_images, train_labels, epochs=5, callbacks=[callbacks])
    
    # 利用 测试数据/测试标签 评估模型
    model.evaluate(test_images_scaled, test_labels)
    
    # 预测数据,并提取第一个(0)的预测结果
    model.predict(test_images_scaled)[0]

     对该案例代码中的一些解释:

    首先这个数据集的每个元素是二维的,即这个数据集存放着若干张图片,每个图片是一个像素 28*28 的二维矩阵存储。

    所以我们的模型第一层使用 Flatten,作用是将二维输入数据转换成一维的。也就是输入层。

    Dense 表示全连接网络,至于参数 激活函数 activation 在上篇博客中有详细解释。

    第二个 Dense 是输出层,一共有 10 个类别,所以输出的神经元个数为 10。这层也叫输出层。

    介于输入输出层之间为 隐含层,这里的隐含层只有一个,也是 Dense,这里神经元数量128,可以自己更改,以得到更好的训练结果。

    配置模型的编译 compile ,优化器为 Adam(),损失函数为 sparse_categorical_crossentropy

    自定义的 Callback 的继承类,防止过拟合。

    fit 训练数据

    evaluate 利用测试集评估模型

    predict 预测数据

  • 相关阅读:
    hdu1069Monkey and Banana(动态规划)
    hdu2571 命运(动态规划)
    hdu1505City Game(动态规划)
    在jvm底层有关于方法区的介绍
    用IDEA查看源码总是跳到.class文件而不是.java文件的解决办法
    如果Son类继承Father类,Father类继承GrandFather类,那么new Son()创建对象的时候是否会执行GrandFather类里面的方法
    看面试题有感:子类构造器(无参或有参)继承的super()方法在何时调用,与静态代码块,普通代码块相比的执行顺序如何的思考及证明
    关于子类构造器调用super()方法的规定
    为什么重写了equals方法后还需要重写hashCode方法
    字符串常量池处在JVM的堆中,那么是在堆的哪个部分呢
  • 原文地址:https://www.cnblogs.com/dongao/p/14380534.html
Copyright © 2011-2022 走看看