zoukankan      html  css  js  c++  java
  • 使用八股搭建手写数据集神经网络

    写在前面

    今天是初五,好好的玩了几天后还是回归到了学习的正轨上。今天主要学习了神经网络的搭建八股,使用这种模型搭建了一个训练手写数据集的神经网络

    搭建网络八股

    六步法:

    import

    train,test

    model = tf.keras.models.Sequential

    model.compile

    model.fit

    model.summary

    总的来说,首先导包,然后指定出训练集和测试集。使用tensorflow提供的API搭建好每层神经网络结构,进行compile,指定优化器损失函数和衡量标准。使用fit函数来训练神经网络,最后使用summary来输出训练结果。

    训练手写数据集

    先来看代码:

    import tensorflow as tf
    
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=['sparse_categorical_accuracy'])
    
    model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
    
    model.summary()
    

    代码不长,我们是严格按照六步法来搭建神经网络,可以看到十分简单。核心部分就是指定神经网络结构。

    总结

    总的来说,使用这种方法搭建神经网络还是十分简单的,但其中的原理一定要好好理解清楚。

  • 相关阅读:
    【HDOJ】4370 0 or 1
    【HDOJ】4122 Alice's mooncake shop
    【HDOJ】4393 Throw nails
    【HDOJ】2385 Stock
    WinCE 输入法编程
    WinCE 自由拼音输入法的测试
    SQL Server CE开发环境建立过程
    【SQL Server CE2.0】创建加密的数据库(源代码)
    【SQL Server CE2.0】打开加密的数据库(源代码)
    EVC在双核PC上调试速度慢的原因
  • 原文地址:https://www.cnblogs.com/wushenjiang/p/14407233.html
Copyright © 2011-2022 走看看