  • Multilayer perceptron (MLP) with TensorFlow 2.1

    This is similar in spirit to an SVM; the perceptron can be regarded as a predecessor of the SVM.
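
    As an aside, here is a minimal NumPy sketch of the classic perceptron update rule (the toy data, learning rate, and variable names are invented purely for illustration); an SVM differs mainly in that it also maximizes the margin of the separating hyperplane:

    import numpy as np

    # Toy 2-D data: two linearly separable classes with labels in {-1, +1}
    X_toy = np.array([[2.0, 1.0], [1.5, 2.0], [-1.0, -1.5], [-2.0, -0.5]])
    y_toy = np.array([1, 1, -1, -1])

    w, b, lr = np.zeros(2), 0.0, 1.0
    for _ in range(10):                           # a few passes over the data
        for xi, yi in zip(X_toy, y_toy):
            if yi * (np.dot(w, xi) + b) <= 0:     # misclassified (or on the boundary)
                w += lr * yi * xi                 # perceptron update rule
                b += lr * yi
    print(w, b)                                   # parameters of a separating hyperplane w·x + b = 0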

    import tensorflow as tf
    import numpy as np
    
    class MNISTLoader():
        def __init__(self):
            mnist = tf.keras.datasets.mnist
            (self.train_data, self.train_label), (self.test_data, self.test_label) = mnist.load_data()
            # MNIST images are uint8 by default (values 0-255). Normalize them to floats in [0, 1] and append a final dimension as the color channel
            self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)      # [60000, 28, 28, 1]
            self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)        # [10000, 28, 28, 1]
            self.train_label = self.train_label.astype(np.int32)    # [60000]
            self.test_label = self.test_label.astype(np.int32)      # [10000]
            self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]
    
        def get_batch(self, batch_size):
            # Randomly draw batch_size samples from the training set and return them
            index = np.random.randint(0, np.shape(self.train_data)[0], batch_size)
            return self.train_data[index, :], self.train_label[index]
    
    # tf.keras.layers.Dense(
    #     units, activation=None, use_bias=True, kernel_initializer='glorot_uniform',
    #     bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None,
    #     activity_regularizer=None, kernel_constraint=None, bias_constraint=None,
    #     **kwargs
    # )
    # units: Positive integer, dimensionality of the output space.
    # activation: Activation function to use. If you don't specify anything, no activation is applied (ie. "linear" activation: a(x) = x).
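    # Illustrative usage sketch (shapes here are assumed just for this example): a Dense layer
    # maps the last input dimension to `units`, e.g.
    #     tf.keras.layers.Dense(units=100, activation=tf.nn.relu)(tf.zeros([32, 784])).shape
    #     # -> TensorShape([32, 100])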
    
    class MLP(tf.keras.Model):
        def __init__(self):
            super().__init__()
            self.flatten = tf.keras.layers.Flatten()    # Flatten collapses all dimensions except the first (batch_size)
            self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)
            self.dense2 = tf.keras.layers.Dense(units=10)
    
        def call(self, inputs):         # [batch_size, 28, 28, 1]
            x = self.flatten(inputs)    # [batch_size, 784]
            x = self.dense1(x)          # [batch_size, 100]
            x = self.dense2(x)          # [batch_size, 10]
            output = tf.nn.softmax(x)
            return output
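
    # Illustrative shape check (dummy batch of 2 images, assumed for this sketch):
    #     MLP()(tf.zeros([2, 28, 28, 1])).shape   # -> TensorShape([2, 10]), one softmax probability vector per image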
    
    num_epochs = 5
    batch_size = 50
    learning_rate = 0.001
    
    model = MLP()
    data_loader = MNISTLoader()
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    
    num_batches = int(data_loader.num_train_data // batch_size * num_epochs)
    for batch_index in range(num_batches):
        X, y = data_loader.get_batch(batch_size)  # randomly sample a batch from the training data
        with tf.GradientTape() as tape:
            y_pred = model(X)  # forward pass: model predictions
            loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)  # per-example loss
            loss = tf.reduce_mean(loss)
            print("batch %d: loss %f" % (batch_index, loss.numpy()))
        grads = tape.gradient(loss, model.variables)  # compute gradients of the loss w.r.t. the model variables
        optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))  # apply gradients to update the variables
    
    # tf.keras.metrics.SparseCategoricalAccuracy is a metric object that accumulates accuracy over batches
    sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    num_batches = int(data_loader.num_test_data // batch_size)
    for batch_index in range(num_batches):
        start_index, end_index = batch_index * batch_size, (batch_index + 1) * batch_size
        # model.predict takes the test data and returns the predicted probabilities
        y_pred = model.predict(data_loader.test_data[start_index: end_index])
        sparse_categorical_accuracy.update_state(y_true=data_loader.test_label[start_index: end_index], y_pred=y_pred)
    print("test accuracy: %f" % sparse_categorical_accuracy.result())

    Reference: https://tf.wiki/zh/basic/models.html (this book is quite good and very hands-on)

  • Original post: https://www.cnblogs.com/lalalatianlalu/p/12498059.html