zoukankan      html  css  js  c++  java
  • [机器学习笔记(三)]保存加载模型的几种方式

    模型的保存和加载

    训练一个相对复杂的模型很有可能需要一段时间,如果是在专门的服务器或计算资源上进行训练那放那里跑就行了。但是如果是在自己的小电脑上跑,就干等着,就可能这段时间电脑都用不了。万一期间要做个其他实验,或者单纯打个游戏放松下就难受了。

    好在TensorFlow提供了训练期间和训练后对模型保存的方法。也就是说,你可以随时暂停一下,然后随时恢复继续训练,甚至别人训练了一半,你可以拿它们的权重继续训练,避免耗费长时间在训练上。

    保存TensorFlow模型有许多不同的方法,取决于你使用的API。因为tf2.0内置了keras这里基本使用keras的方式来完成模型的保存和加载。


    • windows10
    • python==3.6
    • tensorflow==2.0

    准备模型

    创建一个简单的MNIST手写数字识别模型

    获取数据集

    为了训练速度,可以只使用前1000个样本来进行训练和测试

    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
    
    train_labels = train_labels[:1000]
    test_labels = test_labels[:1000]
    
    train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
    test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
    

    定义模型

    这里就直接FC结构不用卷积了

    # Define a simple sequential model
    def create_model():
      model = tf.keras.models.Sequential([
        keras.layers.Dense(512, activation='relu', input_shape=(784,)),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation='softmax')
      ])
    
      model.compile(optimizer='adam',
                    loss='sparse_categorical_crossentropy',
                    metrics=['accuracy'])
    
      return model
    
    # Create a basic model instance
    model = create_model()
    
    # Display the model's architecture
    model.summary()
    

    可以看到如下输出

    Model: "sequential"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    dense (Dense)                (None, 512)               401920    
    _________________________________________________________________
    dropout (Dropout)            (None, 512)               0         
    _________________________________________________________________
    dense_1 (Dense)              (None, 10)                5130      
    =================================================================
    Total params: 407,050
    Trainable params: 407,050
    Non-trainable params: 0
    _________________________________________________________________
    

    使用回调保存模型

    tf.keras.callbacks.ModelCheckpoint回调可以让我们在训练期间和训练结束时保存模型

    创建回调实例

    checkpoint_path = "training_1/cp.ckpt" # your path to save model
    # checkpoint_dir = os.path.dirname(checkpoint_path) # training_1
    
    # Create a callback that saves the model's weights
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                     save_weights_only=True,
                                                     verbose=1)
    
    # Train the model with the new callback
    model.fit(train_images, 
              train_labels,  
              epochs=10,
              validation_data=(test_images,test_labels),
              callbacks=[cp_callback])  # Pass callback to training
    

    这个回调会在每个epoch后保存(会覆盖旧的)checkpoint。

    使用保存的模型

    因为我们只保存了模型的权值,所以在恢复的时候需要先创建一个同样结构的模型,然后载入我们保存的权值。

    # Create a basic model instance
    model = create_model()
    
    # Evaluate the model
    loss, acc = model.evaluate(test_images,  test_labels, verbose=2)
    print("Untrained model, accuracy: {:5.2f}%".format(100*acc))
    

    我们还没有对模型进行训练,所以现在权值是全随机的,Accuracy也不会很好。

    1000/1 - 0s - loss: 2.3598 - accuracy: 0.1360
    Untrained model, accuracy: 13.60%
    

    接下来从保存的checkpoint中加载权值再测试

    # Loads the weights
    model.load_weights(checkpoint_path)
    
    # Re-evaluate the model
    loss,acc = model.evaluate(test_images,  test_labels, verbose=2)
    print("Restored model, accuracy: {:5.2f}%".format(100*acc))
    

    之前保存的权值被加载到新创建的模型中,Accuracy变成了之前训练保存的。

    1000/1 - 0s - loss: 0.3937 - accuracy: 0.8650
    Restored model, accuracy: 86.50%
    

    当然tf.keras.callbacks.ModelCheckpoint有很多参数可以设置,比如保存的间隔,或者你可以让他不覆盖,依次保存为不同名字的模型,这样可以在多个checkpoint中选择没有过拟合的

    手动保存模型

    手动保存模型的方式很简单,使用tf.keras.Model.save_weights方法即可。当然前提是你的模型是一个tf.keras.Model类型。

    # Save the weights
    model.save_weights('./checkpoints/my_checkpoint')
    
    # Create a new model instance
    model = create_model()
    
    # Restore the weights
    model.load_weights('./checkpoints/my_checkpoint')
    

    保存整个模型

    之前我们保存的一直都是模型的权值,而没有保存模型的结构,所以我们恢复的时候需要创建一个相同结构的模型,然后使用tf.keras.Model.load_weights来加载权重。

    保存整个模型可以让我们使用时不需要手动写模型结构,直接从保存的模型里提取出模型。

    安装依赖

    如果要使用HDF5格式来保存整个模型需要安装如下依赖

    pip install -q pyyaml h5py
    

    HDF5格式

    保存为HDF5格式只需要使用tf.keras.Model.save(),且文件名后缀为.h5即可

    # Create and train a new model instance.
    model = create_model()
    model.fit(train_images, train_labels, epochs=5)
    
    # Save the entire model to a HDF5 file.
    # The '.h5' extension indicates that the model shuold be saved to HDF5.
    model.save('my_model.h5') 
    # Or give a save_format arg
    # model.save(filepath="", save_format="h5")
    

    SavedModel格式

    tf.keras.Model.save()方法也可以用于保存SavedModel格式(这个格式是tf的格式,h5则是一个通用的格式)

    # Create and train a new model instance.
    model = create_model()
    model.fit(train_images, train_labels, epochs=5)
    
    # Save the entire model as a SavedModel.
    !mkdir -p saved_model
    model.save('saved_model/my_model') 
    # Or give a save_format arg
    #model.save('saved_model/my_model', save_format='tf')
    

    通过tf.keras.Model.save()的参数save_format可以调整保存的格式,tf2.0以下可能默认保存的是h5格式,tf2.0改为了tf格式。

    加载模型

    加载的时候就不需要先定义一个相同结构的模型,直接从保存的模型中提取出模型结构和权值

    # Recreate the exact same model, including its weights and the optimizer
    new_modelh5 = tf.keras.models.load_model('my_model.h5')
    new_modeltf = tf.keras.models.load_model('saved_model/my_model')
    
    # Show the model architecture
    new_modelh5.summary()
    new_modeltf.summary()
    

    tf中tf.keras.models.load_model可以恢复两种格式保存的模型。

    保存自定义的模型对象

    tf中SavedModelHDF5一大区别是,SavedModel可以自动保存自定义的模型对象(有自定义层的模型),而HDF5需要模型配置来保存自定义模型。

    参考

    [1] https://tensorflow.google.cn/tutorials/keras/save_and_load

  • 相关阅读:
    Codeforces 845E Fire in the City 线段树
    Codeforces 542D Superhero's Job dp (看题解)
    Codeforces 797F Mice and Holes dp
    Codeforces 408D Parcels dp (看题解)
    Codeforces 464D World of Darkraft
    Codeforces 215E Periodical Numbers 容斥原理
    Codeforces 285E Positions in Permutations dp + 容斥原理
    Codeforces 875E Delivery Club dp
    Codeforces 888F Connecting Vertices 区间dp (看题解)
    Codeforces 946F Fibonacci String Subsequences dp (看题解)
  • 原文地址:https://www.cnblogs.com/Axi8/p/11810469.html
Copyright © 2011-2022 走看看