zoukankan      html  css  js  c++  java
  • TensorFlow2中Keras模型保存与加载

    主要记录在Tensorflow2中使用Keras API接口,有关模型保存、加载的内容;

    0. 加载数据、构建网络

    首先,为了方便后续有关模型保存、加载相关代码的正常执行,这里加载mnist数据集、构建一个简单的网络结构。

    import tensorflow as tf
    from libs.load_keras_dataset import load_mnist
    

    注意:下面引入mnist数据集的方式,仅为了方便作者从本地加载、使用;

    mnist_path = '/home/chenz/data/mnist/mnist.npz'
    (x_train, y_train), (x_test, y_test) = load_mnist(data_path=mnist_path)
    print("[INFO] x_train: {}, y_train: {}, x_test: {}, y_test: {}".format(
        x_train.shape, y_train.shape, x_test.shape, y_test.shape
    ))
    train_labels = y_train[:1000]
    test_labels = y_test[:1000]
    train_images = x_train[:1000].reshape(-1, 28*28) / 255.0
    test_images = x_test[:1000].reshape(-1, 28*28) / 255.0
    
    print("[INFO] train_images: {}, train_labels: {}, test_images: {}, test_labels: {}".format(
        train_images.shape, train_labels.shape, test_images.shape, test_labels.shape
    ))
    
    [INFO] x_train: (60000, 28, 28), y_train: (60000,), x_test: (10000, 28, 28), y_test: (10000,)
    [INFO] train_images: (1000, 784), train_labels: (1000,), test_images: (1000, 784), test_labels: (1000,)
    
    

    定义一个方法,用于构建网络结构,并定义网络编译方式,方便后续使用;

    # Build Model
    def create_model():
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(512, activation='relu', input_shape=(784,)),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10)
        ])
        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.1, beta_2=0.2, amsgrad=True),
                      loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=[tf.metrics.SparseCategoricalAccuracy()])
        return model
    

    1. model.save() & model.save_weights()

    在TensorFlow的Keras API中提供了两种保存模型的方式,分别为model.save()model.save_weights(),从字面上可以简单理解,后者仅保存网络结构权重,前者能够保存整个模型结构

    进一步,从源码文档中可以理清两者的区别:

    1.1 model.save()

    该方法能够将整个模型进行保存,以两种方式存储,Tensorflow SavedModelHDF file,保存的文件包括:

    • 模型结构,能够重新实例化模型;
    • 模型权重;
    • 优化器的状态,在上次中断的地方继续训练;

    可以通过tf.keras.models.load_model重新实例化保存的模型,通过该方法返回的模型是已经编译过的模型,除非在之前保存模型的时候就没有被编译;

    利用SequentialFunctional两种形式构建的网络都能够保存成HDF5和SavedModel格式,但是Subclasses形式的模型仅能够保存成SavedModel格式;

    # HDF5格式
    model_name.h5
    
    # Tensorflow SavedModel格式
    ./saved_model
    	assets/
    	saved_model.pb
    	variables/
    

    使用参数说明:

    def save(self,
               filepath,
               overwrite=True,
               include_optimizer=True,
               save_format=None,
               signatures=None,
               options=None):
    
    • filepath表示模型存储的路径;

    • save_format表示以tf或者h5形式进行存储,在TF2中默认tf,TF1中默认h5

    • overwrite表示是否覆盖在目标目录下的已有文件;

    • include_optimizer表示是否保存优化器的状态;

    • signatures仅用于tf形式,具体使用见tf.saved_model.save

    filepathsave_format结合在一起使用,有如下组合方式:

    • filepath.h5为结尾的文件名,则不论save_formattf或者h5,则模型将保存成filename.h5形式;(上级目录需要存在)
    • filepath仅指定文件名,save_format='h5',则模型将保存成filename的HDF形式;
    • filepath指定路径(需存在),save_format='tf',则模型将以Tensorflow SavedModel形式保存到指定路径下;

    注意:filepath不包含后缀时,注意区分是文件目录还是文件名,以tf形式保存,则需要存在指定路径,以h5形式保存,则不能存在相同名称路径;

    1.2 model.save_weights()

    该方法仅保存网络中所有层的权重,

    # HDF5格式
    weights_2 or weights_3.h5
    
    # Tensorflow 格式
    checkpoint 
    weiths_1.data-00000-of-00001
    weigths_1.index
    

    使用参数说明:

    def save_weights(self,
                     filepath,
                     overwrite=True,
                     save_format=None,
                     options=None):
    
    • filepath表示存储的模型文件名或路径;
    • save_format用于表示存储格式,HDF5或者Tensorflow格式;

    filepathsave_format结合使用:

    • filepath以后缀.h5或者.keras结尾,设置save_format=None或者save_format=None,模型将保存成filename.h5filename.keras格式;
    • filepath不含后缀,如果save_format='h5',则模型保存成filename
    • filepath不含后缀,如果save_format='tf'或者save_format=None,则模型保存成Tensorflow格式;

    2. tf.keras.callbacks.ModelCheckpoint

    该方法以回调函数的形式,在模型训练过程中保存模型。

    def __init__(self,
                 filepath,
                 monitor='val_loss',
                 verbose=0,
                 save_best_only=False,
                 save_weights_only=False,
                 mode='auto',
                 save_freq='epoch',
                 options=None,
                 **kwargs):
    

    这里仅提及一点,就是在使用参数save_weights_only时:

    • 设置True,则调用model.save_weights()
    • 设置False,则调用model.save()

    使用方式:

    checkpoint_path = "./saved_model/save_and_load/cp_test_1/cp.ckpt"
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                     save_weights_only=False,
                                                     verbose=1)
    model.fit(train_images, train_labels,
              epochs=3,
              validation_data=(test_images, test_labels),
              callbacks=[cp_callback])
    

    3. tf.keras.models.load_model、model.load_weights

    上面简单说明了模型保存的两种方式,一种是保存整个模型,另一种则是仅保存模型权重;

    完整的模型可以使用tf.keras.models.load_model加载,只包含权重的模型则使用model.load_weights加载;

    3.1 tf.keras.models.load_model

    加载完整模型

    model_path = './saved_model/save_and_load/save_test/test_5/'
    
    model = tf.keras.models.load_model(model_path)
    model.summary()
    
    • 其中,model_path可以为.h5文件的路径,或者Tensorflow SavedModel的路径

    3.2 model.load_weights

    在重新构建网络的基础上,加载模型权重;

    model = create_model()
    model.load_weights("./saved_model/save_and_load/save_test/weights/weights_1")
    model.summary()
    

    4. 总结

    • 官方API是推荐Tensorflow格式进行保存模型,不论是保存整个模型,或是仅保存权重;
  • 相关阅读:
    PAT 1088. Rational Arithmetic
    PAT 1087. All Roads Lead to Rome
    PAT 1086. Tree Traversals Again
    PAT 1085. Perfect Sequence
    PAT 1084. Broken Keyboard
    PAT 1083. List Grades
    PAT 1082. Read Number in Chinese
    求最大公因数
    [转载]Latex文件转成pdf后的字体嵌入问题的解决
    [转载]Matlab有用的小工具小技巧
  • 原文地址:https://www.cnblogs.com/chenzhen0530/p/13943172.html
Copyright © 2011-2022 走看看