
    Save and restore models

    Official tutorial: https://www.tensorflow.org/tutorials/keras/save_and_restore_models

    Saving checkpoints during training

    Checkpoints can be saved automatically during training and at the end of training.
    The weights are stored in a collection of checkpoint-format files that contain only the trained weights, in a binary format.
    This makes it possible to use a trained model without retraining it, or to pick up training where it left off if the training process was interrupted.

    • Checkpoint callback usage: create a checkpoint callback, train the model while passing the ModelCheckpoint callback to it, and the result is a collection of checkpoint files that can be used to share the trained weights (a minimal sketch follows this list).
    • Checkpoint callback options: the callback provides options for giving each generated checkpoint a unique name and for adjusting how frequently checkpoints are created.
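
    As a rough sketch of both points, assuming only the tf.keras.callbacks.ModelCheckpoint API (the path and epoch count here are illustrative, not taken from the full script below):

    import tensorflow as tf

    # Epoch-numbered file names give every checkpoint a unique name,
    # and period=5 writes a checkpoint only every 5 epochs.
    ckpt_path = "demo_training/cp-{epoch:04d}.ckpt"
    cp_callback = tf.keras.callbacks.ModelCheckpoint(ckpt_path,
                                                     save_weights_only=True,
                                                     period=5,
                                                     verbose=1)

    # Passed to fit(), the callback saves the weights while training runs, e.g.:
    # model.fit(train_images, train_labels, epochs=50, callbacks=[cp_callback])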

    Manually saving weights

    Weights can be saved manually with the Model.save_weights method.
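
    A minimal, self-contained sketch (the tiny model and the checkpoint prefix are placeholders, not the ones used in the full script below):

    import tensorflow as tf

    # A small stand-in model; any model with a fixed architecture works the same way.
    model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(784,))])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

    model.save_weights('./demo_ckpt/my_checkpoint')  # writes checkpoint .index and .data-* files

    # Restoring requires a model with the same architecture as the one that was saved.
    restored = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(784,))])
    restored.load_weights('./demo_ckpt/my_checkpoint')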

    Saving the entire model

    The entire model can be saved to a single file containing the weight values, the model configuration (architecture), and the optimizer configuration.
    This lets you checkpoint a model and later resume training from exactly the same state, without needing access to the original code.
    Keras saves a model by inspecting its architecture and uses the HDF5 standard as the basic saving format (a minimal sketch follows these notes).
    Note in particular:

    • TensorFlow optimizers (from tf.train) currently cannot be saved this way.
    • When using such an optimizer, the model needs to be re-compiled after loading, which loses the optimizer's state.
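
    A minimal sketch of the whole-model workflow, assuming only the standard tf.keras save/load_model API (the file name is illustrative, and the h5py package must be installed):

    import tensorflow as tf

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(10, activation='softmax', input_shape=(784,))
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    model.save('demo_model.h5')  # architecture + weights + optimizer config in one HDF5 file

    # load_model rebuilds the identical model, including its weights, with no original code needed.
    restored = tf.keras.models.load_model('demo_model.h5')
    restored.summary()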

    The MNIST dataset

    MNIST (Mixed National Institute of Standards and Technology database) is a computer vision dataset.

    Example

    Script

    GitHub: https://github.com/anliven/Hello-AI/blob/master/Google-Learn-and-use-ML/5_save_and_restore_models.py

    # coding=utf-8
    import tensorflow as tf
    from tensorflow import keras
    import numpy as np
    import pathlib
    import os

    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    print("# TensorFlow version: {}  - tf.keras version: {}".format(tf.VERSION, tf.keras.__version__))  # show versions

    # ### Get the example dataset

    ds_path = str(pathlib.Path.cwd()) + "\\datasets\\mnist\\"  # dataset path (Windows-style separators)
    np_data = np.load(ds_path + "mnist.npz")  # load the numpy-format data
    print("# np_data keys: ", list(np_data.keys()))  # list all keys in the archive

    # Load the MNIST dataset
    (train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data(path=ds_path + "mnist.npz")
    train_labels = train_labels[:1000]
    test_labels = test_labels[:1000]
    train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
    test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0


    # ### Define the model
    def create_model():
        model = tf.keras.models.Sequential([
            keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
            keras.layers.Dropout(0.2),
            keras.layers.Dense(10, activation=tf.nn.softmax)
        ])  # build a simple model
        model.compile(optimizer=tf.keras.optimizers.Adam(),
                      loss=tf.keras.losses.sparse_categorical_crossentropy,
                      metrics=['accuracy'])
        return model


    mod = create_model()
    mod.summary()

    # ### Save checkpoints during training

    # Checkpoint callback usage
    checkpoint_path = "training_1/cp.ckpt"
    checkpoint_dir = os.path.dirname(checkpoint_path)  # directory holding the checkpoints
    cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                     save_weights_only=True,
                                                     verbose=2)  # create the checkpoint callback
    model1 = create_model()
    model1.fit(train_images, train_labels,
               epochs=10,
               validation_data=(test_images, test_labels),
               verbose=0,
               callbacks=[cp_callback]  # pass the ModelCheckpoint callback to the model
               )  # training creates a collection of TensorFlow checkpoint files, updated at the end of each epoch

    model2 = create_model()  # a brand-new untrained model (it must have the same architecture to share weights)
    loss, acc = model2.evaluate(test_images, test_labels)  # evaluate on the test set
    print("# Untrained model2, accuracy: {:5.2f}%".format(100 * acc))  # an untrained model performs at chance level (~10% accuracy)

    model2.load_weights(checkpoint_path)  # load the weights from the checkpoint
    loss, acc = model2.evaluate(test_images, test_labels)  # evaluate again on the test set
    print("# Restored model2, accuracy: {:5.2f}%".format(100 * acc))  # performance improves dramatically

    # Checkpoint callback options
    checkpoint_path2 = "training_2/cp-{epoch:04d}.ckpt"  # "str.format"-style name gives each checkpoint a unique name
    checkpoint_dir2 = os.path.dirname(checkpoint_path2)  # note: derived from checkpoint_path2, not checkpoint_path
    cp_callback2 = tf.keras.callbacks.ModelCheckpoint(checkpoint_path2,
                                                      verbose=1,
                                                      save_weights_only=True,
                                                      period=5  # save a checkpoint every 5 epochs
                                                      )  # create the checkpoint callback
    model3 = create_model()
    model3.fit(train_images, train_labels,
               epochs=50,
               callbacks=[cp_callback2],  # pass the ModelCheckpoint callback to the model
               validation_data=(test_images, test_labels),
               verbose=0)  # train a new model, saving a uniquely named checkpoint every 5 epochs
    latest = tf.train.latest_checkpoint(checkpoint_dir2)
    print("# latest checkpoint: {}".format(latest))  # show the latest checkpoint

    model4 = create_model()  # another brand-new untrained model
    loss, acc = model4.evaluate(test_images, test_labels)  # evaluate model4 (not model2) on the test set
    print("# Untrained model4, accuracy: {:5.2f}%".format(100 * acc))  # an untrained model performs at chance level (~10% accuracy)

    model4.load_weights(latest)  # load the latest checkpoint
    loss, acc = model4.evaluate(test_images, test_labels)  # evaluate again on the test set
    print("# Restored model4, accuracy: {:5.2f}%".format(100 * acc))  # performance improves dramatically

    # ### Manually save weights
    model5 = create_model()
    model5.fit(train_images, train_labels,
               epochs=10,
               validation_data=(test_images, test_labels),
               verbose=0)  # train the model
    model5.save_weights('./training_3/my_checkpoint')  # save the weights manually

    model6 = create_model()
    loss, acc = model6.evaluate(test_images, test_labels)
    print("# Untrained model6, accuracy: {:5.2f}%".format(100 * acc))
    model6.load_weights('./training_3/my_checkpoint')
    loss, acc = model6.evaluate(test_images, test_labels)
    print("# Restored model6, accuracy: {:5.2f}%".format(100 * acc))

    # ### Save the entire model
    model7 = create_model()
    model7.fit(train_images, train_labels, epochs=5)
    model7.save('my_model.h5')  # save the whole model to a single HDF5 file

    model8 = keras.models.load_model('my_model.h5')  # rebuild the exact same model, including weights and optimizer
    model8.summary()
    loss, acc = model8.evaluate(test_images, test_labels)
    print("Restored model8, accuracy: {:5.2f}%".format(100 * acc))

    Run output

    C:\Users\anliven\AppData\Local\conda\conda\envs\mlcc\python.exe D:/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML/5_save_and_restore_models.py
    # TensorFlow version: 1.12.0  - tf.keras version: 2.1.6-tf
    # np_data keys:  ['x_test', 'x_train', 'y_train', 'y_test']
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    dense (Dense)                (None, 512)               401920    
    _________________________________________________________________
    dropout (Dropout)            (None, 512)               0         
    _________________________________________________________________
    dense_1 (Dense)              (None, 10)                5130      
    =================================================================
    Total params: 407,050
    Trainable params: 407,050
    Non-trainable params: 0
    _________________________________________________________________
    
    Epoch 00001: saving model to training_1/cp.ckpt
    Epoch 00002: saving model to training_1/cp.ckpt
    Epoch 00003: saving model to training_1/cp.ckpt
    Epoch 00004: saving model to training_1/cp.ckpt
    Epoch 00005: saving model to training_1/cp.ckpt
    Epoch 00006: saving model to training_1/cp.ckpt
    Epoch 00007: saving model to training_1/cp.ckpt
    Epoch 00008: saving model to training_1/cp.ckpt
    Epoch 00009: saving model to training_1/cp.ckpt
    Epoch 00010: saving model to training_1/cp.ckpt
    
      32/1000 [..............................] - ETA: 3s
    1000/1000 [==============================] - 0s 140us/step
    # Untrained model2, accuracy:  8.20%
    
      32/1000 [..............................] - ETA: 0s
    1000/1000 [==============================] - 0s 40us/step
    # Restored model2, accuracy: 86.40%
    
    Epoch 00005: saving model to training_2/cp-0005.ckpt
    Epoch 00010: saving model to training_2/cp-0010.ckpt
    Epoch 00015: saving model to training_2/cp-0015.ckpt
    Epoch 00020: saving model to training_2/cp-0020.ckpt
    Epoch 00025: saving model to training_2/cp-0025.ckpt
    Epoch 00030: saving model to training_2/cp-0030.ckpt
    Epoch 00035: saving model to training_2/cp-0035.ckpt
    Epoch 00040: saving model to training_2/cp-0040.ckpt
    Epoch 00045: saving model to training_2/cp-0045.ckpt
    Epoch 00050: saving model to training_2/cp-0050.ckpt
    
    # latest checkpoint: training_1\cp.ckpt
    
      32/1000 [..............................] - ETA: 3s
    1000/1000 [==============================] - 0s 140us/step
    # Untrained model4, accuracy: 86.40%
    
      32/1000 [..............................] - ETA: 2s
    1000/1000 [==============================] - 0s 110us/step
    # Restored model4, accuracy: 86.40%
    
      32/1000 [..............................] - ETA: 5s
    1000/1000 [==============================] - 0s 220us/step
    # Untrained model6, accuracy: 18.20%
    
      32/1000 [..............................] - ETA: 0s
    1000/1000 [==============================] - 0s 40us/step
    # Restored model6, accuracy: 87.40%
    Epoch 1/5
    
      32/1000 [..............................] - ETA: 9s - loss: 2.4141 - acc: 0.0625
     320/1000 [========>.....................] - ETA: 0s - loss: 1.8229 - acc: 0.4469
     576/1000 [================>.............] - ETA: 0s - loss: 1.4932 - acc: 0.5694
     864/1000 [========================>.....] - ETA: 0s - loss: 1.2624 - acc: 0.6481
    1000/1000 [==============================] - 1s 530us/step - loss: 1.1978 - acc: 0.6620
    Epoch 2/5
    
      32/1000 [..............................] - ETA: 0s - loss: 0.5490 - acc: 0.8750
     320/1000 [========>.....................] - ETA: 0s - loss: 0.4832 - acc: 0.8594
     576/1000 [================>.............] - ETA: 0s - loss: 0.4630 - acc: 0.8715
     864/1000 [========================>.....] - ETA: 0s - loss: 0.4356 - acc: 0.8808
    1000/1000 [==============================] - 0s 200us/step - loss: 0.4298 - acc: 0.8790
    Epoch 3/5
    
      32/1000 [..............................] - ETA: 0s - loss: 0.1681 - acc: 0.9688
     320/1000 [========>.....................] - ETA: 0s - loss: 0.2826 - acc: 0.9437
     576/1000 [================>.............] - ETA: 0s - loss: 0.2774 - acc: 0.9340
     832/1000 [=======================>......] - ETA: 0s - loss: 0.2740 - acc: 0.9327
    1000/1000 [==============================] - 0s 200us/step - loss: 0.2781 - acc: 0.9280
    Epoch 4/5
    
      32/1000 [..............................] - ETA: 0s - loss: 0.1589 - acc: 0.9688
     288/1000 [=======>......................] - ETA: 0s - loss: 0.2169 - acc: 0.9410
     608/1000 [=================>............] - ETA: 0s - loss: 0.2186 - acc: 0.9457
     864/1000 [========================>.....] - ETA: 0s - loss: 0.2231 - acc: 0.9479
    1000/1000 [==============================] - 0s 200us/step - loss: 0.2164 - acc: 0.9480
    Epoch 5/5
    
      32/1000 [..............................] - ETA: 0s - loss: 0.1095 - acc: 1.0000
     352/1000 [=========>....................] - ETA: 0s - loss: 0.1631 - acc: 0.9744
     608/1000 [=================>............] - ETA: 0s - loss: 0.1671 - acc: 0.9638
     864/1000 [========================>.....] - ETA: 0s - loss: 0.1545 - acc: 0.9688
    1000/1000 [==============================] - 0s 210us/step - loss: 0.1538 - acc: 0.9670
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    dense_14 (Dense)             (None, 512)               401920    
    _________________________________________________________________
    dropout_7 (Dropout)          (None, 512)               0         
    _________________________________________________________________
    dense_15 (Dense)             (None, 10)                5130      
    =================================================================
    Total params: 407,050
    Trainable params: 407,050
    Non-trainable params: 0
    _________________________________________________________________
    
      32/1000 [..............................] - ETA: 3s
    1000/1000 [==============================] - 0s 150us/step
    Restored model8, accuracy: 86.10%
    
    Process finished with exit code 0

    Generated files

    anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
    $ ls -l training_1
    total 1601
    -rw-r--r-- 1 anliven 197121      71 5月   5 23:36 checkpoint
    -rw-r--r-- 1 anliven 197121 1631508 5月   5 23:36 cp.ckpt.data-00000-of-00001
    -rw-r--r-- 1 anliven 197121     647 5月   5 23:36 cp.ckpt.index
    
    anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
    $ ls -l training_2
    total 16001
    -rw-r--r-- 1 anliven 197121      81 5月   5 23:37 checkpoint
    -rw-r--r-- 1 anliven 197121 1631508 5月   5 23:36 cp-0005.ckpt.data-00000-of-00001
    -rw-r--r-- 1 anliven 197121     647 5月   5 23:36 cp-0005.ckpt.index
    -rw-r--r-- 1 anliven 197121 1631508 5月   5 23:36 cp-0010.ckpt.data-00000-of-00001
    -rw-r--r-- 1 anliven 197121     647 5月   5 23:36 cp-0010.ckpt.index
    -rw-r--r-- 1 anliven 197121 1631508 5月   5 23:36 cp-0015.ckpt.data-00000-of-00001
    -rw-r--r-- 1 anliven 197121     647 5月   5 23:36 cp-0015.ckpt.index
    -rw-r--r-- 1 anliven 197121 1631508 5月   5 23:36 cp-0020.ckpt.data-00000-of-00001
    -rw-r--r-- 1 anliven 197121     647 5月   5 23:36 cp-0020.ckpt.index
    -rw-r--r-- 1 anliven 197121 1631508 5月   5 23:36 cp-0025.ckpt.data-00000-of-00001
    -rw-r--r-- 1 anliven 197121     647 5月   5 23:36 cp-0025.ckpt.index
    -rw-r--r-- 1 anliven 197121 1631508 5月   5 23:37 cp-0030.ckpt.data-00000-of-00001
    -rw-r--r-- 1 anliven 197121     647 5月   5 23:37 cp-0030.ckpt.index
    -rw-r--r-- 1 anliven 197121 1631508 5月   5 23:37 cp-0035.ckpt.data-00000-of-00001
    -rw-r--r-- 1 anliven 197121     647 5月   5 23:37 cp-0035.ckpt.index
    -rw-r--r-- 1 anliven 197121 1631508 5月   5 23:37 cp-0040.ckpt.data-00000-of-00001
    -rw-r--r-- 1 anliven 197121     647 5月   5 23:37 cp-0040.ckpt.index
    -rw-r--r-- 1 anliven 197121 1631508 5月   5 23:37 cp-0045.ckpt.data-00000-of-00001
    -rw-r--r-- 1 anliven 197121     647 5月   5 23:37 cp-0045.ckpt.index
    -rw-r--r-- 1 anliven 197121 1631508 5月   5 23:37 cp-0050.ckpt.data-00000-of-00001
    -rw-r--r-- 1 anliven 197121     647 5月   5 23:37 cp-0050.ckpt.index
    
    anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
    $ ls -l training_3
    total 1601
    -rw-r--r-- 1 anliven 197121      83 5月   5 23:37 checkpoint
    -rw-r--r-- 1 anliven 197121 1631517 5月   5 23:37 my_checkpoint.data-00000-of-00001
    -rw-r--r-- 1 anliven 197121     647 5月   5 23:37 my_checkpoint.index
    
    anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
    $ ls -l my_model.h5
    -rw-r--r-- 1 anliven 197121 4909112 5月   5 23:37 my_model.h5

    Troubleshooting

    Problem: the following warning message appears.

    WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x00000280FD318780>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.
    
    Consider using a TensorFlow optimizer from `tf.train`.

    Resolution:

    This is an expected warning; it has no effect on how the script runs or on its results, so it can be ignored for now. For reference, the warning's own suggestion is sketched below.
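
    The warning suggests using a TensorFlow optimizer from tf.train. A rough sketch of that variant under the TF 1.x API (not part of the original script; whether the optimizer state is actually restored still depends on how the checkpoints are written and read back):

    import tensorflow as tf

    def create_model_with_tf_optimizer():
        model = tf.keras.models.Sequential([
            tf.keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10, activation=tf.nn.softmax)
        ])
        # tf.train.AdamOptimizer is the TF 1.x optimizer the warning refers to;
        # compiling with it avoids the save_weights warning about Keras optimizers.
        model.compile(optimizer=tf.train.AdamOptimizer(learning_rate=0.001),
                      loss=tf.keras.losses.sparse_categorical_crossentropy,
                      metrics=['accuracy'])
        return model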

    Original post: https://www.cnblogs.com/anliven/p/10817233.html