zoukankan      html  css  js  c++  java
  • tensorflow学习——02Fashion MNIST数据集神经网络训练

    Fashion MMIST数据集是一个现成的数据集,可以直接用来学习深度学习

    其中包括70000张图片,10个类别,28*28像素,用于训练神经元网络

     

     

    上面是一个神经元示意图,有3个输入x1 x2 x3,并且有是三个权重w1 w2 w3,  b是他的截距,也是一个常数

    神经元就把他的输入分别乘以权重加起来,并且加上一个常数,得到一个值z,将这个值放入到激活函数中,函数的输出就是整个神经元的输出

     

    Relu激活函数用于中间层,只有输入是正数的时候才会有输出,输入是负数的时候输出为0

    Softmax函数将输出压缩到0-1之间

     

    #加载Fashion MNIST数据集
    #第一次执行加载数据集的时候会看到有下载的进度条
    import tensorflow as tf
    from tensorflow import keras
    fashion_mnist=keras.datasets.fashion_mnist
    (train_images,train_labels),(test_images,test_labels)=fashion_mnist.load_data()
    #(训练的图片,训练图片的标签),(测试的图片,测试图片的标签)
    print(train_images.shape) #测试图片有60000张,每张图片是28*28像素
    import matplotlib.pyplot as plt
    #如果出现没有matplotlib这个包,可以直接使用conda install matplotlib命令进行下载
    plt.imshow(train_images[0]) #显示训练集的第一张图片
    #构建神经元网络模型
    #三层
    #第一层用于接受输入,每张图片都是28*28,所以shape是28,28
    #第二层是中间层,有128个神经元,这个数字是自己可以任意修改的
    #第三层是输出层,分类类别有10个,所以有10个神经元
    #model=keras.Sequential([
    #    keras.layers.Flatten(input_shape(28,28)),
    #    keras.layers.Dense(128,activation=tf.nn.relu),
    #    keras.layers.Dense(10,activation=tf.nn.softmax)
    #])
    
    model=keras.Sequential() #构建网络模型
    model.add(keras.layers.Flatten(input_shape=(28,28))) #输入层
    model.add(keras.layers.Dense(128,activation=tf.nn.relu)) #加一个中间层
    model.add(keras.layers.Dense(10,activation=tf.nn.softmax)) #加一个输出层
    model.summary() #观察构造的网络模型

    其中78428*28

    100480=(784+1)*128

    1290=(128+1)*10

    输入是28*28=784以及一个bias的截距

    中间层到输出层就是128个神经元加一个bias截距

    上面这是一个全连接的神经网络

    如果想要学习神经网络的理论,可以看下面图片中的这个资料

    自动终止训练

    如果训练次数过渡会出现过拟合的情况

    训练的loss和测试的loss出现分叉的时候一般就是过拟合

     

    #神经元网络模型并不是训练次数越多越好

     

    class myCallback(tf.keras.callbacks.Callback):
        def on_epoch_end(self,epoch,logs={}):
            if(logs.get('loss')<0.4): #损失函数的值小于0.4的时候
                print("
    loss is low so cancelling training")
                self.model.stop_training=True  #终止训练
                
    callbacks=myCallback()
    mnist=tf.keras.datasets.fashion_mnist
    (training_images,training_labels),(test_images,test_labels)=mnist.load_data()
    training_images_scaled=training_images/255
    test_images_scaled=test_images/255.0
    model=tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512,activation=tf.nn.relu),
        tf.keras.layers.Dense(10,activation=tf.nn.softmax)
    ])
    model.compile(optimizer="adam",loss="sparse_categorical_crossentropy",metrics=['accuracy'])
    model.fit(training_images_scaled,training_labels,epochs=5,callbacks=[callbacks])

     

     

     

     

     

     


    作者:孙建钊
    出处:http://www.cnblogs.com/sunjianzhao/
    本文版权归作者和博客园共有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文连接,否则保留追究法律责任的权利。

  • 相关阅读:
    高可用性机制
    Moodle课程资源系统安装
    Windows 10 安装 chocolatey
    centos7安装samba服务器
    抽签网页板代码
    CentOS7系统操作httpd服务
    centos7.2下放行端口
    centos7没有netstat命令的解决办法
    Linux
    Linux下常用服务的端口号超详细整理
  • 原文地址:https://www.cnblogs.com/sunjianzhao/p/15404137.html
Copyright © 2011-2022 走看看