zoukankan      html  css  js  c++  java
  • TensorFlow2中Keras模型保存与加载

    主要记录在Tensorflow2中使用Keras API接口,有关模型保存、加载的内容;

    0. 加载数据、构建网络

    首先,为了方便后续有关模型保存、加载相关代码的正常执行,这里加载mnist数据集、构建一个简单的网络结构。

    import tensorflow as tf
    from libs.load_keras_dataset import load_mnist
    

    注意:下面引入mnist数据集的方式,仅为了方便作者从本地加载、使用;

    mnist_path = '/home/chenz/data/mnist/mnist.npz'
    (x_train, y_train), (x_test, y_test) = load_mnist(data_path=mnist_path)
    print("[INFO] x_train: {}, y_train: {}, x_test: {}, y_test: {}".format(
        x_train.shape, y_train.shape, x_test.shape, y_test.shape
    ))
    train_labels = y_train[:1000]
    test_labels = y_test[:1000]
    train_images = x_train[:1000].reshape(-1, 28*28) / 255.0
    test_images = x_test[:1000].reshape(-1, 28*28) / 255.0
    
    print("[INFO] train_images: {}, train_labels: {}, test_images: {}, test_labels: {}".format(
        train_images.shape, train_labels.shape, test_images.shape, test_labels.shape
    ))
    
    [INFO] x_train: (60000, 28, 28), y_train: (60000,), x_test: (10000, 28, 28), y_test: (10000,)
    [INFO] train_images: (1000, 784), train_labels: (1000,), test_images: (1000, 784), test_labels: (1000,)
    
    

    定义一个方法,用于构建网络结构,并定义网络编译方式,方便后续使用;

    # Build Model
    def create_model():
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(512, activation='relu', input_shape=(784,)),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10)
        ])
        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.1, beta_2=0.2, amsgrad=True),
                      loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=[tf.metrics.SparseCategoricalAccuracy()])
        return model
    

    1. model.save() & model.save_weights()

    在TensorFlow的Keras API中提供了两种保存模型的方式,分别为model.save()model.save_weights(),从字面上可以简单理解,后者仅保存网络结构权重,前者能够保存整个模型结构

    进一步,从源码文档中可以理清两者的区别:

    1.1 model.save()

    该方法能够将整个模型进行保存,以两种方式存储,Tensorflow SavedModelHDF file,保存的文件包括:

    • 模型结构,能够重新实例化模型;
    • 模型权重;
    • 优化器的状态,在上次中断的地方继续训练;

    可以通过tf.keras.models.load_model重新实例化保存的模型,通过该方法返回的模型是已经编译过的模型,除非在之前保存模型的时候就没有被编译;

    利用SequentialFunctional两种形式构建的网络都能够保存成HDF5和SavedModel格式,但是Subclasses形式的模型仅能够保存成SavedModel格式;

    # HDF5格式
    model_name.h5
    
    # Tensorflow SavedModel格式
    ./saved_model
    	assets/
    	saved_model.pb
    	variables/
    

    使用参数说明:

    def save(self,
               filepath,
               overwrite=True,
               include_optimizer=True,
               save_format=None,
               signatures=None,
               options=None):
    
    • filepath表示模型存储的路径;

    • save_format表示以tf或者h5形式进行存储,在TF2中默认tf,TF1中默认h5

    • overwrite表示是否覆盖在目标目录下的已有文件;

    • include_optimizer表示是否保存优化器的状态;

    • signatures仅用于tf形式,具体使用见tf.saved_model.save

    filepathsave_format结合在一起使用,有如下组合方式:

    • filepath.h5为结尾的文件名,则不论save_formattf或者h5,则模型将保存成filename.h5形式;(上级目录需要存在)
    • filepath仅指定文件名,save_format='h5',则模型将保存成filename的HDF形式;
    • filepath指定路径(需存在),save_format='tf',则模型将以Tensorflow SavedModel形式保存到指定路径下;

    注意:filepath不包含后缀时,注意区分是文件目录还是文件名,以tf形式保存,则需要存在指定路径,以h5形式保存,则不能存在相同名称路径;

    1.2 model.save_weights()

    该方法仅保存网络中所有层的权重,

    # HDF5格式
    weights_2 or weights_3.h5
    
    # Tensorflow 格式
    checkpoint 
    weiths_1.data-00000-of-00001
    weigths_1.index
    

    使用参数说明:

    def save_weights(self,
                     filepath,
                     overwrite=True,
                     save_format=None,
                     options=None):
    
    • filepath表示存储的模型文件名或路径;
    • save_format用于表示存储格式,HDF5或者Tensorflow格式;

    filepathsave_format结合使用:

    • filepath以后缀.h5或者.keras结尾,设置save_format=None或者save_format=None,模型将保存成filename.h5filename.keras格式;
    • filepath不含后缀,如果save_format='h5',则模型保存成filename
    • filepath不含后缀,如果save_format='tf'或者save_format=None,则模型保存成Tensorflow格式;

    2. tf.keras.callbacks.ModelCheckpoint

    该方法以回调函数的形式,在模型训练过程中保存模型。

    def __init__(self,
                 filepath,
                 monitor='val_loss',
                 verbose=0,
                 save_best_only=False,
                 save_weights_only=False,
                 mode='auto',
                 save_freq='epoch',
                 options=None,
                 **kwargs):
    

    这里仅提及一点,就是在使用参数save_weights_only时:

    • 设置True,则调用model.save_weights()
    • 设置False,则调用model.save()

    使用方式:

    checkpoint_path = "./saved_model/save_and_load/cp_test_1/cp.ckpt"
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                     save_weights_only=False,
                                                     verbose=1)
    model.fit(train_images, train_labels,
              epochs=3,
              validation_data=(test_images, test_labels),
              callbacks=[cp_callback])
    

    3. tf.keras.models.load_model、model.load_weights

    上面简单说明了模型保存的两种方式,一种是保存整个模型,另一种则是仅保存模型权重;

    完整的模型可以使用tf.keras.models.load_model加载,只包含权重的模型则使用model.load_weights加载;

    3.1 tf.keras.models.load_model

    加载完整模型

    model_path = './saved_model/save_and_load/save_test/test_5/'
    
    model = tf.keras.models.load_model(model_path)
    model.summary()
    
    • 其中,model_path可以为.h5文件的路径,或者Tensorflow SavedModel的路径

    3.2 model.load_weights

    在重新构建网络的基础上,加载模型权重;

    model = create_model()
    model.load_weights("./saved_model/save_and_load/save_test/weights/weights_1")
    model.summary()
    

    4. 总结

    • 官方API是推荐Tensorflow格式进行保存模型,不论是保存整个模型,或是仅保存权重;
  • 相关阅读:
    shared_ptr weak_ptr boost 内存管理
    _vimrc win7 gvim
    qt 拖放
    数学小魔术 斐波那契数列
    qt4 程序 移植到 qt5
    (转)字符串匹配算法总结
    c++11
    BM 字符串匹配
    编译qt5 demo
    c++ 类库 学习资源
  • 原文地址:https://www.cnblogs.com/chenzhen0530/p/13943172.html
Copyright © 2011-2022 走看看