zoukankan      html  css  js  c++  java
  • 【tensorflow】神经网络:断点续训

    断点续训,即在一次训练结束后,可以先将得到的最优训练参数保存起来,待到下次训练时,直接读取最优参数,在此基础上继续训练。

    读取模型参数:

    存储模型参数的文件格式为 ckpt(checkpoint)。

    生成 ckpt 文件时,会同步生成索引表,所以可通过判断是否存在索引表来判断是否存在模型参数。

    # 模型参数保存路径
    checkpoint_save_path = "class4/MNIST_FC/checkpoint/mnist.ckpt"  
    if os.path.exists(checkpoint_save_path + ".index"): model.load_weights(checkpoint_save_path)

    保存模型参数:

    # 定义回调函数,在模型训练时,回调函数会被执行,完成保留参数操作
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
      # 文件保存路径
      filepath=checkpoint_save_path,
    
      # 是否只保留模型参数
      save_weights_only=True,
    
      # 是否只保留最优结果
      save_best_only=True
    )
    
    # 执行训练过程,保存新的训练参数
    history = model.fit(x_train, y_train,
                batch_size=32, epochs=5,
                validation_data=(x_test, y_test),
                validation_freq=1,
                callbacks=[cp_callback])

    代码:

    import tensorflow as tf
    import os
    
    # 读取输入特征和标签
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    
    # 数据归一化,减小计算量,方便神经网络吸收
    x_train, x_test = x_train/255.0, x_test/255.0
    
    # 声明网络结构
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(10, activation="softmax")
    ])
    
    # 配置训练方法
    model.compile(optimizer="adam",
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=[tf.keras.metrics.sparse_categorical_accuracy])
    
    # 如果存在参数文件,直接读取,在此基础上继续训练
    checkpoint_save_path = "class4/MNIST_FC/checkpoint/mnist.ckpt"  # 模型参数保存路径
    if os.path.exists(checkpoint_save_path + ".index"):
        model.load_weights(checkpoint_save_path)
    
    # 定义回调函数,在模型训练时,完成保留参数操作
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                     save_weights_only=True,
                                                     save_best_only=True)
    
    # 执行训练过程,保存新的训练参数
    history = model.fit(x_train, y_train,
                        batch_size=32, epochs=5,
                        validation_data=(x_test, y_test),
                        validation_freq=1,
                        callbacks=[cp_callback])
    
    # 打印网络结构和参数
    model.summary()
  • 相关阅读:
    selenium操控浏览器
    DOM
    bug记录
    log日志
    linux 搭建 telnet + tftp
    linux 搭建 MeepoPS+Socket
    php常见面试题(2)
    php常见面试题(1)
    laravel 5 支付宝支付教程
    计算机进位制原理
  • 原文地址:https://www.cnblogs.com/bjxqmy/p/13538364.html
Copyright © 2011-2022 走看看