zoukankan      html  css  js  c++  java
  • TensorFlow学习笔记Mnist全连接模型实践

    import os
    from tensorflow.keras.datasets import mnist
    import tensorflow as tf
    from tensorflow.python.keras import Model
    from tensorflow.python.keras.layers import Flatten, Dense
    
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    
    x_train, x_test = x_train/255.0, x_test/255.0
    
    checkpoint_save_path = './checkpoint/model.ckpt'
    
    
    # 搭建模型类
    class MnistModel(Model):
        def __init__(self):
            super(MnistModel, self).__init__()
            self.flatten = Flatten()
            self.dense1 = Dense(128, activation='relu')
            self.dense2 = Dense(10, activation='softmax')
    
        def call(self, x):
            x = self.flatten(x)
            x = self.dense1(x)
            y = self.dense2(x)
            return y
    
    
    model = MnistModel()
    
    # 模型优化
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.sparse_categorical_crossentropy,
                  metrics=['sparse_categorical_accuracy'])
    
    # callback保存模型
    model_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True,
                                                        save_best_only=True)
    
    # 曾经保存过,直接加载权重参数
    if os.path.exists(checkpoint_save_path + '.index'):
        model.load_weights(checkpoint_save_path)
    
    # 开始训练
    model.fit(x=x_train, y=y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), callbacks=[model_callback])
    
    # 结果总览
    model.summary()
    
    # 保存模型参数到文本,方便查看
    # with open('./weight.txt', 'w') as f:
    #     for i in model.trainable_variables:
    #         f.write(str(i.name) + '\n')
    #         f.write(str(i.shape) + '\n')
    #         # f.write(str(i.numpy()) + '\n')   # 这行有问题
  • 相关阅读:
    记一次事件:由于资源管理器没有关闭所导致数据库挂起
    脚本恢复控制文件
    数据库恢复至某个时间点
    EXPDP/IMPDP
    导入与导出详解
    ORACLE DIRECTORY目录管理步骤
    Linux and Oracle常用目录详解
    omitting directory何意
    在RAC执行相关操作发生ora-01031:insufficient privileges解决方法
    MySQL8.0安装
  • 原文地址:https://www.cnblogs.com/yiduobaozhiblog1/p/15672644.html
Copyright © 2011-2022 走看看