zoukankan      html  css  js  c++  java
  • tensorflow2.0——手写数据集预测(全连接神经3层网络)

    import tensorflow as tf
    import numpy as np
    from tensorflow.keras import datasets, layers, optimizers
    
    
    # 加载手写数字数据
    mnist = tf.keras.datasets.mnist
    (train_x, train_y), (test_x, test_y) = mnist.load_data()
    
    xs = tf.convert_to_tensor(train_x, dtype=tf.float32)/255                    #   除255将像素点值变为0-1的值
    ys = tf.convert_to_tensor(train_y.reshape(-1, 1), dtype=tf.float32)
    db = tf.data.Dataset.from_tensor_slices((xs, ys)).batch(200)                #   将标记值和样本封装为元组,且每次以200个样本作为求梯度整体
    
    #   设置超参
    iter = 100
    learn_rate = 0.01
    #   定义模型和优化器
    model = tf.keras.Sequential([
        layers.Dense(512, activation='relu'),
        layers.Dense(256, activation='relu'),           #   全连接
        layers.Dense(10)
    ])
    optimizer = optimizers.SGD(learning_rate=learn_rate)            #   优化器
    
    #   迭代代码
    for i in range(iter):
        print('i:',i)
        for step,(x,y) in enumerate(db):                            #   对每个batch样本做梯度计算
            #   将标记值转化为one-hot编码
            y_hot = np.zeros((y.shape[0], 10))
            for row_index in range(y.shape[0]):
                # print('这是i:{}, step:{} :'.format(i,step))
                y_hot[row_index][int(y[row_index].numpy()[0])] = 1
    
            with tf.GradientTape() as tape:
                x = tf.reshape(x,(-1,28*28))               #   将28*28展开为784
                out = model(x)
                loss = tf.reduce_mean(tf.square(out-y_hot))
            grads = tape.gradient(loss,model.trainable_variables)               #   求梯度
            optimizer.apply_gradients(zip(grads,model.trainable_variables))     #   优化器进行参数优化
            if step % 100 == 0:
                print('i:{} ,step:{} ,loss:{}'.format(i, step,loss.numpy()))
                #   求准确率
                acc = tf.equal(tf.argmax(out,axis=1),tf.argmax(y_hot,axis=1))
                acc = tf.cast(acc,tf.int8)
                acc = tf.reduce_mean(tf.cast(acc,tf.float32))
                print('acc:',acc.numpy())
  • 相关阅读:
    洛谷 P2958 [USACO09OCT]木瓜的丛林Papaya Jungle
    洛谷 P1400 塔
    10-2 集合之List
    主从数据库
    【单元测试】
    Pen Editor
    appendGrid
    动画
    JavaScript框架设计 第14章 动画引擎
    >>>
  • 原文地址:https://www.cnblogs.com/cxhzy/p/13463910.html
Copyright © 2011-2022 走看看