zoukankan      html  css  js  c++  java
  • TensorFlow 2.0多层感知机实现MNIST手写数字识别

     1 import tensorflow as tf
     2 import numpy as np
     3 """
     4 tf.keras.datasets 获得数据集 并预处理
     5 """
     6 class MNISTLoader():
     7     def __init__(self):                #加载类的构造函数
     8         mnist = tf.keras.datasets.mnist
     9         (self.train_data,self.train_label),(self.test_data,self.test_label) = mnist.load_data()  #(6000,28,28)
    10         #axis  = 1 ,相当于在最内层每个元素加括号变成  n, 1 ,在后面加个1 这个维度;   axis = 0,就是维度变成  1,n,在中间加个1这个维度
    11         self.train_data = np.expand_dims(self.train_data.astype(np.float32)/255.0,axis = -1)
    12         self.test_data = np.expand_dims(self.test_data.astype(np.float32)/255.0,axis = -1)
    13         self.test_label = self.test_label.astype(np.int32)      # [10000]
    14         self.train_label = self.train_label.astype(np.int32)    # [60000]
    15         self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]
    16     def get_batch(self,betch_size):
    17         index = np.random.randint(0,np.shape(self.train_data)[0],batch_size)      #(low, high, size),size可以是一个元组,那么会生成一个多维数组
    18         #从(0,6000)中随机抽取batch_size个随机整数,数字包括下界,不包括上界
    19         return self.train_data[index, :], self.train_label[index]  #返回抽取的batch数字对应的样本,[[],[],[],[],...],     [1,3,4,5,7,...]    一个batch返回两个tensor
    20 
    21 """
    22 模型构建,tf.keras.Model ,tf.keras.layers
    23 """
    24 class MLP(tf.keras.Model):
    25     def __init__(self):
    26         super().__init__()
    27         self.flatten = tf.keras.layers.Flatten()
    28         self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)
    29         self.dense2 = tf.keras.layers.Dense(units=10)
    30 
    31     def call(self, inputs):
    32         x = self.flatten(inputs)
    33         x = self.dense1(x)
    34         x = self.dense2(x)
    35         output = tf.nn.softmax(x)  # 用了多分类激活函数在最后输出层
    36         # softmax,归一化指数函数。且变得可导了。平滑化的 argmax 函数,平滑功能就是使各类的直接数字差别不会特别大了,即soft.
    37         return output
    38 """
    39 模型的训练 tf.keras.losses , tf.keras.optimizer
    40 """
    41 if __name__ =="__main__":
    42     num_epochs = 5
    43     batch_size = 50
    44     lr = 0.001
    45     model = MLP()
    46     data_loader = MNISTLoader()
    47     optimizer = tf.keras.optimizers.Adam(learning_rate = lr)  #优化器就两点 :1.哪种优化算法,2.优化算法里超参的定义,如学习率等
    48 
    49     num_batches = int(data_loader.num_train_data // batch_size * num_epochs)  # //向下取整
    50     for batch_index in range(num_batches):
    51         X, y = data_loader.get_batch(batch_size)
    52         with tf.GradientTape() as tape:
    53             y_pred = model(X)
    54             loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
    55             loss = tf.reduce_mean(loss)  # 就是计算给定tensor的某个轴向上的平均值,此处是loss张量
    56             print("batch %d:loss %f" % (batch_index, loss.numpy()))
    57         grads = tape.gradient(loss, model.variables)
    58         optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))  # 这里是因为参数更新函数接受的是(1,w1),(2,w2),;;;参数对的形式
    59 
    60 """
    61 模型评估 tf.keras.metrics
    62 """
    63     sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()  #实例化一个评估器
    64     num_batches = int(data_loader.num_test_data // batch_size)
    65     for batch_index in range(num_batches):   #测试集也分了batch
    66         start_index, end_index = batch_index * batch_size, (batch_index + 1) * batch_size
    67         y_pred = model.predict(data_loader.test_data[start_index: end_index])
    68         sparse_categorical_accuracy.update_state(y_true=data_loader.test_label[start_index: end_index], y_pred=y_pred)
    69         #update_state方法记录并处理结果
    70     print("test accuracy: %f" % sparse_categorical_accuracy.result())
  • 相关阅读:
    Flex中States的用法
    MAX脚本翻译教学
    WARN No appenders could be found for logger 解决
    解压版(绿色版)Tomcat配置
    Bootstrap入门
    什么时候用margin、padding
    简易的商品统计
    块级元素&行内元素
    不定宽元素水平居中
    JavaScript与表单交互(表单验证模型)
  • 原文地址:https://www.cnblogs.com/Henry-ZHAO/p/12725248.html
Copyright © 2011-2022 走看看