zoukankan      html  css  js  c++  java
  • TensorFlow2_200729系列---18、手写数字识别(层方式)

    TensorFlow2_200729系列---18、手写数字识别(层方式)

    一、总结

    一句话总结:

    之前是张量(tensor)的方式,体现细节和原理,现在是层方式,更加简便简洁
    model = Sequential([
        layers.Dense(256, activation=tf.nn.relu), # [b, 784] => [b, 256]
        layers.Dense(128, activation=tf.nn.relu), # [b, 256] => [b, 128]
        layers.Dense(64, activation=tf.nn.relu), # [b, 128] => [b, 64]
        layers.Dense(32, activation=tf.nn.relu), # [b, 64] => [b, 32]
        layers.Dense(10) # [b, 32] => [b, 10], 330 = 32*10 + 10
    ])
    model.build(input_shape=[None, 28*28])
    model.summary()
    
    使用模型:
    logits = model(x)

    1、tensorflow的keras模块包括datasets, layers, optimizers, Sequential, metrics?

    from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

    二、手写数字识别(层方式)

    博客对应课程的视频位置:

    import  os
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    
    import tensorflow as tf
    from    tensorflow import keras
    from    tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
    
    assert tf.__version__.startswith('2.')
    
    # 预处理函数
    # 数据归一化
    def preprocess(x, y):
    
        x = tf.cast(x, dtype=tf.float32) / 255.
        y = tf.cast(y, dtype=tf.int32)
        return x,y
    
    # 自动加载数据
    (x, y), (x_test, y_test) = datasets.fashion_mnist.load_data()
    print(x.shape, y.shape)
    
    # batch size:一批的大小
    batchsz = 128
    
    # 训练数据
    # Creates a `Dataset` whose elements are slices of the given tensors.
    db = tf.data.Dataset.from_tensor_slices((x,y))
    # shuffle打乱并且分batch
    db = db.map(preprocess).shuffle(10000).batch(batchsz)
    
    # 测试数据
    db_test = tf.data.Dataset.from_tensor_slices((x_test,y_test))
    # 测试数据不需要打乱
    db_test = db_test.map(preprocess).batch(batchsz)
    
    # 迭代器
    db_iter = iter(db)
    sample = next(db_iter)
    print('batch:', sample[0].shape, sample[1].shape)
    # batch: (128, 28, 28) (128,)
    
    model = Sequential([
        layers.Dense(256, activation=tf.nn.relu), # [b, 784] => [b, 256]
        layers.Dense(128, activation=tf.nn.relu), # [b, 256] => [b, 128]
        layers.Dense(64, activation=tf.nn.relu), # [b, 128] => [b, 64]
        layers.Dense(32, activation=tf.nn.relu), # [b, 64] => [b, 32]
        layers.Dense(10) # [b, 32] => [b, 10], 330 = 32*10 + 10
    ])
    model.build(input_shape=[None, 28*28])
    model.summary()
    # w = w - lr*grad
    optimizer = optimizers.Adam(lr=1e-3)
    
    def main():
    
    
        for epoch in range(30):
    
    
            for step, (x,y) in enumerate(db):
    
                # x: [b, 28, 28] => [b, 784]
                # y: [b]
                x = tf.reshape(x, [-1, 28*28])
    
                with tf.GradientTape() as tape:
                    # [b, 784] => [b, 10]
                    logits = model(x)
                    y_onehot = tf.one_hot(y, depth=10)
                    # [b]
                    loss_mse = tf.reduce_mean(tf.losses.MSE(y_onehot, logits))
                    loss_ce = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
                    loss_ce = tf.reduce_mean(loss_ce)
    
                # model.trainable_variables:表示参数,也就是w和b    
                grads = tape.gradient(loss_ce, model.trainable_variables)
                # 原地更新
                optimizer.apply_gradients(zip(grads, model.trainable_variables))
    
    
                if step % 100 == 0:
                    print(epoch, step, 'loss:', float(loss_ce), float(loss_mse))
    
    
            # test
            total_correct = 0
            total_num = 0
            for x,y in db_test:
    
                # x: [b, 28, 28] => [b, 784]
                # y: [b]
                x = tf.reshape(x, [-1, 28*28])
                # [b, 10]
                logits = model(x)
                # logits => prob, [b, 10]
                prob = tf.nn.softmax(logits, axis=1)
                # [b, 10] => [b], int64
                pred = tf.argmax(prob, axis=1)
                pred = tf.cast(pred, dtype=tf.int32)
                # pred:[b]
                # y: [b]
                # correct: [b], True: equal, False: not equal
                correct = tf.equal(pred, y)
                correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))
    
                total_correct += int(correct)
                total_num += x.shape[0]
    
            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)
    
    
    if __name__ == '__main__':
        main()
    (60000, 28, 28) (60000,)
    batch: (128, 28, 28) (128,)
    Model: "sequential"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    dense (Dense)                multiple                  200960    
    _________________________________________________________________
    dense_1 (Dense)              multiple                  32896     
    _________________________________________________________________
    dense_2 (Dense)              multiple                  8256      
    _________________________________________________________________
    dense_3 (Dense)              multiple                  2080      
    _________________________________________________________________
    dense_4 (Dense)              multiple                  330       
    =================================================================
    Total params: 244,522
    Trainable params: 244,522
    Non-trainable params: 0
    _________________________________________________________________
    0 0 loss: 2.2831406593322754 0.12128783017396927
    0 100 loss: 0.6292909383773804 18.471372604370117
    0 200 loss: 0.4110489785671234 18.012493133544922
    0 300 loss: 0.4045095443725586 14.464678764343262
    0 400 loss: 0.5457848310470581 18.89963722229004
    0 test acc: 0.8464
    1 0 loss: 0.3746733069419861 17.479198455810547
    1 100 loss: 0.4358266592025757 22.785991668701172
    1 200 loss: 0.34541454911231995 21.296964645385742
    1 300 loss: 0.6067743301391602 15.614158630371094
    1 400 loss: 0.3972930908203125 19.6492862701416
    1 test acc: 0.8624
    2 0 loss: 0.38962411880493164 21.97345733642578
    2 100 loss: 0.32657697796821594 24.56244659423828
    2 200 loss: 0.33351215720176697 21.94455909729004
    2 300 loss: 0.33902692794799805 26.513694763183594
    2 400 loss: 0.374667763710022 24.564464569091797
    2 test acc: 0.87
    3 0 loss: 0.3156976103782654 26.225505828857422
    3 100 loss: 0.25029417872428894 31.27167320251465
    3 200 loss: 0.2644597291946411 30.24835205078125
    3 300 loss: 0.27939510345458984 28.048099517822266
    3 400 loss: 0.28537383675575256 29.133506774902344
    3 test acc: 0.8732
    4 0 loss: 0.3248022496700287 28.466659545898438
    4 100 loss: 0.2858565151691437 30.900304794311523
    4 200 loss: 0.27638182044029236 29.498920440673828
    4 300 loss: 0.230974480509758 30.112049102783203
    4 400 loss: 0.2854992747306824 30.296920776367188
    4 test acc: 0.8702
    5 0 loss: 0.3455154001712799 40.34062194824219
    5 100 loss: 0.1356138437986374 44.47637939453125
    5 200 loss: 0.35171258449554443 36.933929443359375
    5 300 loss: 0.1946554183959961 40.9914436340332
    5 400 loss: 0.317558228969574 43.98896026611328
    5 test acc: 0.8766
    6 0 loss: 0.220485657453537 43.64817810058594
    6 100 loss: 0.21728813648223877 39.432037353515625
    6 200 loss: 0.1949092447757721 48.853843688964844
    6 300 loss: 0.15513527393341064 44.78699493408203
    6 400 loss: 0.21645382046699524 49.537960052490234
    6 test acc: 0.8784
    7 0 loss: 0.3043939471244812 43.625526428222656
    7 100 loss: 0.3803304433822632 47.13135528564453
    7 200 loss: 0.19133441150188446 45.429298400878906
    7 300 loss: 0.1776151806116104 41.924537658691406
    7 400 loss: 0.16863960027694702 42.79949951171875
    7 test acc: 0.8738
    8 0 loss: 0.1874811351299286 44.651611328125
    8 100 loss: 0.18167194724082947 51.90435028076172
    8 200 loss: 0.20349520444869995 49.169429779052734
    8 300 loss: 0.2610611915588379 48.555458068847656
    8 400 loss: 0.2457474023103714 55.176734924316406
    8 test acc: 0.8802
    9 0 loss: 0.2630460858345032 53.88508224487305
    9 100 loss: 0.20558330416679382 63.04207992553711
    9 200 loss: 0.18517211079597473 65.81611633300781
    9 300 loss: 0.20496012270450592 55.272369384765625
    9 400 loss: 0.22070546448230743 59.18791198730469
    9 test acc: 0.8839
    10 0 loss: 0.15226173400878906 50.21383285522461
    10 100 loss: 0.10652273893356323 66.88746643066406
    10 200 loss: 0.15289798378944397 69.40467071533203
    10 300 loss: 0.24505120515823364 52.915138244628906
    10 400 loss: 0.21931703388690948 62.56816101074219
    10 test acc: 0.8879
    11 0 loss: 0.28399163484573364 66.48797607421875
    11 100 loss: 0.2144084870815277 68.24934387207031
    11 200 loss: 0.2513824999332428 49.95161056518555
    11 300 loss: 0.23070569336414337 53.46424102783203
    11 400 loss: 0.1969212144613266 59.240577697753906
    11 test acc: 0.8885
    12 0 loss: 0.1924872249364853 55.86700439453125
    12 100 loss: 0.21166521310806274 69.66253662109375
    12 200 loss: 0.09095969796180725 76.57353210449219
    12 300 loss: 0.15812699496746063 67.44322204589844
    12 400 loss: 0.20802710950374603 84.34611511230469
    12 test acc: 0.8818
    13 0 loss: 0.22456292808055878 63.22731018066406
    13 100 loss: 0.1939781904220581 75.43051147460938
    13 200 loss: 0.3054753839969635 81.66931915283203
    13 300 loss: 0.23840418457984924 63.1304931640625
    13 400 loss: 0.2474619597196579 69.8863525390625
    13 test acc: 0.8885
    14 0 loss: 0.16988956928253174 85.44975280761719
    14 100 loss: 0.2409886121749878 94.68711853027344
    14 200 loss: 0.19825829565525055 75.6685791015625
    14 300 loss: 0.26892679929733276 102.5340576171875
    14 400 loss: 0.10896225273609161 90.99492645263672
    14 test acc: 0.8929
    15 0 loss: 0.20140430331230164 88.07635498046875
    15 100 loss: 0.11349458247423172 94.00519561767578
    15 200 loss: 0.1777578443288803 75.71000671386719
    15 300 loss: 0.27039724588394165 74.60504150390625
    15 400 loss: 0.2390979528427124 95.77617645263672
    15 test acc: 0.8875
    16 0 loss: 0.22253449261188507 74.66510009765625
    16 100 loss: 0.15607573091983795 91.78387451171875
    16 200 loss: 0.15405383706092834 109.38969421386719
    16 300 loss: 0.10432792454957962 93.83760070800781
    16 400 loss: 0.127157062292099 81.89163208007812
    16 test acc: 0.8902
    17 0 loss: 0.1748599112033844 74.17550659179688
    17 100 loss: 0.21128180623054504 101.88328552246094
    17 200 loss: 0.213323712348938 99.44528198242188
    17 300 loss: 0.1905888170003891 92.34651947021484
    17 400 loss: 0.08545292168855667 118.2466049194336
    17 test acc: 0.892
    18 0 loss: 0.13534505665302277 107.93522644042969
    18 100 loss: 0.10933603346347809 120.73545837402344
    18 200 loss: 0.21846728026866913 107.94190979003906
    18 300 loss: 0.2655482292175293 107.08270263671875
    18 400 loss: 0.23332582414150238 110.47785949707031
    18 test acc: 0.892
    19 0 loss: 0.16872575879096985 112.55984497070312
    19 100 loss: 0.2029556930065155 105.87848663330078
    19 200 loss: 0.13815325498580933 110.57797241210938
    19 300 loss: 0.26082828640937805 106.12140655517578
    19 400 loss: 0.15341421961784363 129.8838348388672
    19 test acc: 0.8934
    20 0 loss: 0.29162901639938354 111.24371337890625
    20 100 loss: 0.23025716841220856 105.3729248046875
    20 200 loss: 0.13770082592964172 119.57967376708984
    20 300 loss: 0.24651116132736206 120.30937957763672
    20 400 loss: 0.21254345774650574 107.91622161865234
    20 test acc: 0.8917
    21 0 loss: 0.09702333062887192 100.24187469482422
    21 100 loss: 0.15910854935646057 129.3473358154297
    21 200 loss: 0.0851014256477356 138.20169067382812
    21 300 loss: 0.1595071405172348 118.91499328613281
    21 400 loss: 0.2024853229522705 109.33180236816406
    21 test acc: 0.8899
    22 0 loss: 0.13461339473724365 140.96359252929688
    22 100 loss: 0.18555812537670135 163.75796508789062
    22 200 loss: 0.20990914106369019 129.61654663085938
    22 300 loss: 0.13982388377189636 127.89125061035156
    22 400 loss: 0.15919993817806244 130.1854248046875
    22 test acc: 0.8945
    23 0 loss: 0.1364736706018448 139.3319549560547
    23 100 loss: 0.15799979865550995 164.3890380859375
    23 200 loss: 0.13190290331840515 147.40843200683594
    23 300 loss: 0.21002498269081116 128.39404296875
    23 400 loss: 0.20846235752105713 150.4744873046875
    23 test acc: 0.8977
    24 0 loss: 0.180350199341774 148.44241333007812
    24 100 loss: 0.13309326767921448 141.54718017578125
    24 200 loss: 0.09000922739505768 144.95884704589844
    24 300 loss: 0.09340814501047134 149.47250366210938
    24 400 loss: 0.11350023001432419 135.44602966308594
    24 test acc: 0.8948
    25 0 loss: 0.10290056467056274 129.1949920654297
    25 100 loss: 0.10859610140323639 147.93728637695312
    25 200 loss: 0.15649116039276123 157.18661499023438
    25 300 loss: 0.09786863625049591 163.36807250976562
    25 400 loss: 0.13727730512619019 151.49111938476562
    25 test acc: 0.8928
    26 0 loss: 0.14575082063674927 152.08493041992188
    26 100 loss: 0.08358915150165558 157.06666564941406
    
    26 200 loss: 0.13103337585926056 141.79519653320312
    26 300 loss: 0.1875842809677124 163.2027587890625
    26 400 loss: 0.2265387624502182 177.1208038330078
    26 test acc: 0.896
    27 0 loss: 0.12106870114803314 160.7439422607422
    27 100 loss: 0.11055881530046463 181.66207885742188
    27 200 loss: 0.08392684161663055 155.105712890625
    27 300 loss: 0.14919903874397278 153.12997436523438
    27 400 loss: 0.11113433539867401 176.20297241210938
    27 test acc: 0.8837
    28 0 loss: 0.20585989952087402 201.09121704101562
    28 100 loss: 0.1687045842409134 180.786865234375
    28 200 loss: 0.13319478929042816 169.2735595703125
    28 300 loss: 0.08168964087963104 159.86993408203125
    28 400 loss: 0.11371222138404846 196.85433959960938
    28 test acc: 0.8918
    29 0 loss: 0.120663121342659 209.4883575439453
    29 100 loss: 0.12449634820222855 171.43267822265625
    29 200 loss: 0.18858373165130615 171.96971130371094
    29 300 loss: 0.07583779096603394 193.2484130859375
    29 400 loss: 0.10753950476646423 199.6551513671875
    29 test acc: 0.8944
     
    我的旨在学过的东西不再忘记(主要使用艾宾浩斯遗忘曲线算法及其它智能学习复习算法)的偏公益性质的完全免费的编程视频学习网站: fanrenyi.com;有各种前端、后端、算法、大数据、人工智能等课程。
    博主25岁,前端后端算法大数据人工智能都有兴趣。
    大家有啥都可以加博主联系方式(qq404006308,微信fan404006308)互相交流。工作、生活、心境,可以互相启迪。
    聊技术,交朋友,修心境,qq404006308,微信fan404006308
    26岁,真心找女朋友,非诚勿扰,微信fan404006308,qq404006308
    人工智能群:939687837

    作者相关推荐

  • 相关阅读:
    ‘随意’不是个好词,‘用心’才是
    servlet
    tomcat服务器
    http协议
    jdbc(Java数据库连接)
    dbcp和druid(数据库连接池)
    关于GitHub
    冒泡和递归
    python内置函数
    python四
  • 原文地址:https://www.cnblogs.com/Renyi-Fan/p/13439979.html
Copyright © 2011-2022 走看看