  • The handwritten-digit problem (MNIST)

    import os
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'   # silence TensorFlow's non-essential log output

    import tensorflow.compat.v1 as tf
    from tensorflow import keras
    from tensorflow.keras import layers, optimizers, datasets
    tf.enable_eager_execution()   # enable eager mode (TF1-compat) so tensors evaluate immediately, no sess.run() needed
    
    # Load and preprocess the dataset
    (x, y), (x_val, y_val) = datasets.mnist.load_data()
    x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.   # scale pixel values to [0, 1]
    y = tf.convert_to_tensor(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)                            # integer labels -> one-hot vectors of length 10
    print(x.shape, y.shape)
    train_dataset = tf.data.Dataset.from_tensor_slices((x, y))
    train_dataset = train_dataset.batch(200)               # each step loads a batch of 200 images
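    # At this point x is (60000, 28, 28) scaled to [0, 1] and y is (60000, 10) one-hot,
    # so train_dataset yields about 300 batches shaped ([200, 28, 28], [200, 10]) per epoch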
    
     
    
    # Model: fully connected (Dense) layers that reduce 784 inputs to 10 class scores
    model = keras.Sequential([
        layers.Dense(512, activation='relu'),   # ReLU supplies the nonlinearity
        layers.Dense(256, activation='relu'),
        layers.Dense(10)])                      # raw scores (logits), one per digit class
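    # Parameter count: (784*512 + 512) + (512*256 + 256) + (256*10 + 10) = 535,818 weights and biases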
    
    optimizer = optimizers.SGD(learning_rate=0.001)
    
    
    def train_epoch(epoch):

        # Step 4: loop over the batches; 60k samples / 200 per batch ≈ 300 steps per epoch
        for step, (x, y) in enumerate(train_dataset):

            with tf.GradientTape() as tape:
                # flatten the images: [b, 28, 28] => [b, 784]
                x = tf.reshape(x, (-1, 28*28))
                # Step 1: compute the output
                # [b, 784] => [b, 10]
                out = model(x)
                # Step 2: compute the loss, i.e. the squared error summed over the
                # 10 outputs and averaged over the batch of b samples
                loss = tf.reduce_sum(tf.square(out - y)) / x.shape[0]

            # Step 3: optimize and update w1, w2, w3, b1, b2, b3
            grads = tape.gradient(loss, model.trainable_variables)   # gradients of the loss w.r.t. every weight and bias
            # w' = w - lr * grad
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            if step % 100 == 0:
                print(epoch, step, 'loss:', loss.numpy())
    
    
    
    def train():
        # iterate over the whole dataset for 30 epochs
        for epoch in range(30):
            train_epoch(epoch)
    
    
    
    
    
    
    if __name__ == '__main__':
        train()
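
    The loop above only tracks the training loss; the held-out split (x_val, y_val) is never used. Below is a minimal sketch of an accuracy check, assuming the model, x_val, and y_val from the listing are in scope (the evaluate name is illustrative, not part of the original code):

    # Sketch: measure accuracy on the validation split; assumes model/x_val/y_val from above
    def evaluate():
        x = tf.convert_to_tensor(x_val, dtype=tf.float32) / 255.  # same scaling as training
        x = tf.reshape(x, (-1, 28*28))                            # flatten: [b, 28, 28] -> [b, 784]
        logits = model(x)                                         # [b, 10] raw class scores
        preds = tf.argmax(logits, axis=1, output_type=tf.int32)   # predicted digit per image
        labels = tf.convert_to_tensor(y_val, dtype=tf.int32)
        acc = tf.reduce_mean(tf.cast(tf.equal(preds, labels), tf.float32))
        print('val acc:', acc.numpy())

    Calling evaluate() after train() inside the __main__ guard would then report the final validation accuracy.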

    Result:

    (60000, 28, 28) (60000, 10)
    0 0 loss: 2.1289964
    0 100 loss: 0.96601397
    0 200 loss: 0.8044617
    1 0 loss: 0.65632385
    1 100 loss: 0.71072084
    1 200 loss: 0.6174767
    2 0 loss: 0.53884405
    2 100 loss: 0.61792874
    2 200 loss: 0.53729916
    3 0 loss: 0.48332796
    3 100 loss: 0.5644321
    3 200 loss: 0.48922828
    4 0 loss: 0.44779533
    4 100 loss: 0.5270611
    4 200 loss: 0.45555627
    5 0 loss: 0.42214122
    5 100 loss: 0.49914017
    5 200 loss: 0.42974195
    6 0 loss: 0.4022831
    6 100 loss: 0.4767412
    6 200 loss: 0.4090542
    7 0 loss: 0.38604406
    7 100 loss: 0.45791557
    7 200 loss: 0.39167565
    8 0 loss: 0.3723324
    8 100 loss: 0.44173408
    8 200 loss: 0.37691337
    9 0 loss: 0.360519
    9 100 loss: 0.42779246
    9 200 loss: 0.36422646
    10 0 loss: 0.35006583
    10 100 loss: 0.41538823
    10 200 loss: 0.3530626
    11 0 loss: 0.3407312
    11 100 loss: 0.40423894
    11 200 loss: 0.34306836
    12 0 loss: 0.3323893
    12 100 loss: 0.3939416
    12 200 loss: 0.3339965
    13 0 loss: 0.3248109
    13 100 loss: 0.38446128
    13 200 loss: 0.32582656
    14 0 loss: 0.31788555
    14 100 loss: 0.37571213
    14 200 loss: 0.3183561
    15 0 loss: 0.3113761
    15 100 loss: 0.3676333
    15 200 loss: 0.31151268
    16 0 loss: 0.30531833
    16 100 loss: 0.36009517
    16 200 loss: 0.30516908
    17 0 loss: 0.2996593
    17 100 loss: 0.35302532
    17 200 loss: 0.29931957
    18 0 loss: 0.29437816
    18 100 loss: 0.34642395
    18 200 loss: 0.2938386
    19 0 loss: 0.2894483
    19 100 loss: 0.34028184
    19 200 loss: 0.2887537
    20 0 loss: 0.28483075
    20 100 loss: 0.3345565
    20 200 loss: 0.28399432
    21 0 loss: 0.2804789
    21 100 loss: 0.3291541
    21 200 loss: 0.27953643
    22 0 loss: 0.27633134
    22 100 loss: 0.32407936
    22 200 loss: 0.27533495
    23 0 loss: 0.27240857
    23 100 loss: 0.3192857
    23 200 loss: 0.27136424
    24 0 loss: 0.26872116
    24 100 loss: 0.31474534
    24 200 loss: 0.26758516
    25 0 loss: 0.2652039
    25 100 loss: 0.31041327
    25 200 loss: 0.26399314
    26 0 loss: 0.26185223
    26 100 loss: 0.30627567
    26 200 loss: 0.2605623
    27 0 loss: 0.25865546
    27 100 loss: 0.3023752
    27 200 loss: 0.25727862
    28 0 loss: 0.2556298
    28 100 loss: 0.29863724
    28 200 loss: 0.25413704
    29 0 loss: 0.25273502
    29 100 loss: 0.29504693
    29 200 loss: 0.2511155
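
    The batch loss falls steadily, from about 2.13 at the first step to roughly 0.25 by epoch 29, so the pipeline is clearly learning. Note that Step 2 trains with a mean-squared-error loss on raw logits; a common alternative for classification (a sketch, not the original code) is softmax cross-entropy, which drops in as a one-line change:

    # Sketch: replace Step 2's MSE with softmax cross-entropy
    # (out = logits from model(x), y = the one-hot label batch)
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=out))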
  • Original post: https://www.cnblogs.com/a155-/p/14279285.html