zoukankan      html  css  js  c++  java
  • Tensorflow函数式API的使用

    Tensorflow函数式API的使用

    一、总结

    一句话总结:

    I、在我们使用tensorflow时,如果不能使用函数式api进行编程,那么一些复杂的神经网络结构就不会实现出来,只能使用简单的单向模型进行一层一层地堆叠。
    II、如果稍微复杂一点,遇到了Resnet这种带有残差模块的神经网络,那么用简单的神经网络堆叠的方式则不可能把这种网络堆叠出来。
    搭建全连接神经网络:
    
    input=keras.Input(shape=(28,28))
    x=keras.layers.Flatten()(input)#调用input
    x=keras.layers.Dense(32,activation="relu")(x)
    x=keras.layers.Dropout(0.5)(x)#一层一层的进行调用上一层的结果
    output=keras.layers.Dense(10,activation="softmax")(x)
    model=keras.Model(inputs=input,outputs=output)
    model.summary()

    二、Tensorflow函数式API的使用

    转自或参考:Tensorflow函数式API的使用
    https://www.cnblogs.com/geeksongs/p/13204568.html

    在我们使用tensorflow时,如果不能使用函数式api进行编程,那么一些复杂的神经网络结构就不会实现出来,只能使用简单的单向模型进行一层一层地堆叠。如果稍微复杂一点,遇到了Resnet这种带有残差模块的神经网络,那么用简单的神经网络堆叠的方式则不可能把这种网络堆叠出来。下面我们来使用函数式API来编写一个简单的全连接神经网络:
    首先导包:

    from tensorflow import keras
    import tensorflow as tf
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt

    导入图片数据集:mnist

    (train_image,train_label),(test_image,test_label)=tf.keras.datasets.fashion_mnist.load_data()

    归一化:

    train_image=train_image/255
    test_image=test_image/255#进行数据的归一化,加快计算的进程

    搭建全连接神经网络:

    input=keras.Input(shape=(28,28))
    x=keras.layers.Flatten()(input)#调用input
    x=keras.layers.Dense(32,activation="relu")(x)
    x=keras.layers.Dropout(0.5)(x)#一层一层的进行调用上一层的结果
    output=keras.layers.Dense(10,activation="softmax")(x)
    model=keras.Model(inputs=input,outputs=output)
    model.summary()

    输出:

    Model: "model"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_1 (InputLayer)         [(None, 28, 28)]          0         
    _________________________________________________________________
    flatten (Flatten)            (None, 784)               0         
    _________________________________________________________________
    dense (Dense)                (None, 32)                25120     
    _________________________________________________________________
    dropout (Dropout)            (None, 32)                0         
    _________________________________________________________________
    dense_1 (Dense)              (None, 10)                330       
    =================================================================
    Total params: 25,450
    Trainable params: 25,450
    Non-trainable params: 0

    拟合模型:

    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                  loss="sparse_categorical_crossentropy",
                  metrics=['acc']
    )
    history=model.fit(train_image,
                      train_label,
                      epochs=15,
                      validation_data=(test_image,test_label))

    输出:

    Train on 60000 samples, validate on 10000 samples
    Epoch 1/15
    60000/60000 [==============================] - 4s 64us/sample - loss: 0.8931 - acc: 0.6737 - val_loss: 0.5185 - val_acc: 0.8160
    Epoch 2/15
    60000/60000 [==============================] - 3s 57us/sample - loss: 0.6757 - acc: 0.7508 - val_loss: 0.4805 - val_acc: 0.8230
    Epoch 3/15
    60000/60000 [==============================] - 3s 50us/sample - loss: 0.6336 - acc: 0.7647 - val_loss: 0.4587 - val_acc: 0.8369
    Epoch 4/15
    60000/60000 [==============================] - 3s 49us/sample - loss: 0.6174 - acc: 0.7689 - val_loss: 0.4712 - val_acc: 0.8294
    Epoch 5/15
    60000/60000 [==============================] - 3s 48us/sample - loss: 0.6080 - acc: 0.7732 - val_loss: 0.4511 - val_acc: 0.8404
    Epoch 6/15
    60000/60000 [==============================] - 3s 48us/sample - loss: 0.5932 - acc: 0.7773 - val_loss: 0.4545 - val_acc: 0.8407
    Epoch 7/15
    60000/60000 [==============================] - 3s 47us/sample - loss: 0.5886 - acc: 0.7772 - val_loss: 0.4394 - val_acc: 0.8428
    Epoch 8/15
    60000/60000 [==============================] - 3s 52us/sample - loss: 0.5820 - acc: 0.7788 - val_loss: 0.4338 - val_acc: 0.8506
    Epoch 9/15
    60000/60000 [==============================] - 3s 48us/sample - loss: 0.5742 - acc: 0.7839 - val_loss: 0.4393 - val_acc: 0.8454
    Epoch 10/15
    60000/60000 [==============================] - 3s 49us/sample - loss: 0.5713 - acc: 0.7847 - val_loss: 0.4422 - val_acc: 0.8477
    Epoch 11/15
    60000/60000 [==============================] - 3s 47us/sample - loss: 0.5642 - acc: 0.7858 - val_loss: 0.4325 - val_acc: 0.8488
    Epoch 12/15
    60000/60000 [==============================] - 3s 48us/sample - loss: 0.5582 - acc: 0.7873 - val_loss: 0.4294 - val_acc: 0.8492
    Epoch 13/15
    60000/60000 [==============================] - 3s 48us/sample - loss: 0.5574 - acc: 0.7882 - val_loss: 0.4263 - val_acc: 0.8523
    Epoch 14/15
    60000/60000 [==============================] - 3s 48us/sample - loss: 0.5524 - acc: 0.7888 - val_loss: 0.4350 - val_acc: 0.8448
    Epoch 15/15
    60000/60000 [==============================] - 3s 47us/sample - loss: 0.5486 - acc: 0.7901 - val_loss: 0.4297 - val_acc: 0.8493

    最后验证集的精度达到了84%,这是一个仅仅使用全连接神经网络和softmax就能够得到的一个很不错的结果了!

     
    我的旨在学过的东西不再忘记(主要使用艾宾浩斯遗忘曲线算法及其它智能学习复习算法)的偏公益性质的完全免费的编程视频学习网站: fanrenyi.com;有各种前端、后端、算法、大数据、人工智能等课程。
    博主25岁,前端后端算法大数据人工智能都有兴趣。
    大家有啥都可以加博主联系方式(qq404006308,微信fan404006308)互相交流。工作、生活、心境,可以互相启迪。
    聊技术,交朋友,修心境,qq404006308,微信fan404006308
    26岁,真心找女朋友,非诚勿扰,微信fan404006308,qq404006308
    人工智能群:939687837

    作者相关推荐

  • 相关阅读:
    ubuntu frp 自编译。本文不能按顺序来 请自己理解
    油猴子 自改脚本 删除页面 div 上下翻页 视频页内全屏 右键可用
    批处理bat 删除指定文件夹下的文件及文件夹
    LUA 静态库 动态库 LD_LIBRARY_PATH 动态库的查找路径 GCC “-l”参数
    delphi 判断奇数偶数
    sf.net
    cmake指定mingw编译器的方法
    关闭delphi ide皮肤
    arch pacman被删除 重装
    delphi 匿名方法访问var参数
  • 原文地址:https://www.cnblogs.com/Renyi-Fan/p/13394895.html
Copyright © 2011-2022 走看看