"""Train a two-layer fully connected MNIST classifier with resumable checkpoints.

The script loads MNIST, scales pixels to [0, 1], restores previously saved
weights if a checkpoint exists, trains for 5 epochs while keeping the weights
with the best validation loss, and prints a model summary.
"""
import os
import sys

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Flatten

# NOTE: Model and the layers are imported from the public `tensorflow.keras`
# API, the same one used for the optimizer, loss and callback below. The
# private `tensorflow.python.keras` package is a separate legacy copy of Keras
# on newer TF releases; mixing the two can make compile() reject the optimizer.

# Load MNIST and scale the pixel values (0-255) to [0, 1].
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Path prefix for the weight checkpoint. The existence check below looks for
# "<prefix>.index", i.e. it assumes the TensorFlow checkpoint format used by
# Keras 2 `save_weights`.
# NOTE(review): Keras 3 (TF >= 2.16) requires weight paths ending in
# ".weights.h5" — confirm the TF version this runs on.
checkpoint_save_path = './checkpoint/model.ckpt'


class MnistModel(Model):
    """Flatten -> Dense(128, relu) -> Dense(10, softmax) classifier."""

    def __init__(self):
        super().__init__()
        self.flatten = Flatten()
        self.dense1 = Dense(128, activation='relu')
        self.dense2 = Dense(10, activation='softmax')

    def call(self, x):
        """Return the 10-way softmax output for a batch of inputs *x*."""
        x = self.flatten(x)
        x = self.dense1(x)
        y = self.dense2(x)
        return y


def save_weights_to_text(model, path='./weight.txt'):
    """Write every trainable variable's name, shape and full values to *path*.

    Args:
        model: a built Keras model whose ``trainable_variables`` are dumped.
        path: destination text file; overwritten if it exists.
    """
    with open(path, 'w', encoding='utf-8') as f:
        for var in model.trainable_variables:
            f.write(str(var.name) + '\n')
            f.write(str(var.shape) + '\n')
            # str(ndarray) abbreviates large arrays with "..."; a huge
            # threshold makes numpy print every element.
            f.write(np.array2string(var.numpy(), threshold=sys.maxsize) + '\n')


model = MnistModel()

# Configure optimizer, loss and metric. Labels are integer class ids, hence
# the sparse categorical loss/accuracy.
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.sparse_categorical_crossentropy,
              metrics=['sparse_categorical_accuracy'])

# Save weights only, and only when validation loss improves.
model_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                    monitor='val_loss',
                                                    save_weights_only=True,
                                                    save_best_only=True)

# A checkpoint from a previous run exists: resume from its weights.
if os.path.exists(checkpoint_save_path + '.index'):
    model.load_weights(checkpoint_save_path)

# Train, validating on the test split after each epoch.
model.fit(x=x_train, y=y_train, batch_size=32, epochs=5,
          validation_data=(x_test, y_test),
          callbacks=[model_callback])

# Print the layer overview.
model.summary()

# Set to True to dump the trained parameters to ./weight.txt for inspection.
SAVE_WEIGHTS_TO_TEXT = False
if SAVE_WEIGHTS_TO_TEXT:
    save_weights_to_text(model)