zoukankan      html  css  js  c++  java
  • TensorFlow——训练模型的保存和载入的方法介绍

    我们在训练好模型的时候,通常是要将模型进行保存的,以便于下次能够直接的将训练好的模型进行载入。

    1.保存模型

    首先需要建立一个saver,然后在session中通过saver的save即可将模型保存起来,具体的代码流程如下

    # 前面的是定义好的模型结构

    # 前面的代码是模型的定义代码
    
    saver = tf.train.Saver()    # 生成saver
     
    with tf.Session() as sess:
        sess.run(init)          # 模型的初始化
        # 
        # 模型的训练代码,当模型训练完毕后,下面就可以对模型进行保存了
        # 
        saver.save(sess, "model/linear")     # 当路径不存在时,会自动创建路径

    2.载入模型

    将模型保存后,在保存的路径中,可以看到生成的模型路径,下面我们就能够加载模型了:

    saver = tf.train.Saver()
    
    with tf.Session() as sess:
        # 可以对模型进行初始化,也可以不进行模型的初始化,因为后面的加载会覆盖之前的
        # 初始化操作
        sess.run(init)
    
        saver.restore(sess, "model/linear")

    下面我们以linearmodel为例进行讲解:

    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    import os
    
    train_x = np.linspace(-5, 3, 50)
    train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5
    
    plt.plot(train_x, train_y, 'r.')
    plt.grid(True)
    plt.show()
    
    X = tf.placeholder(dtype=tf.float32)
    Y = tf.placeholder(dtype=tf.float32)
    
    w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
    b = tf.Variable(tf.random.truncated_normal([1]), name='bias')
    
    z = tf.multiply(X, w) + b
    
    cost = tf.reduce_mean(tf.square(Y - z))
    learning_rate = 0.01
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
    
    init = tf.global_variables_initializer()
    
    training_epochs = 20
    display_step = 2
    
    
    saver = tf.train.Saver()
    
    
    if __name__ == '__main__':
        with tf.Session() as sess:
            sess.run(init)
            if os.path.exists("model/"):
                saver.restore(sess, "model/linear")
    
                w_, b_ = sess.run([w, b])
    
                print(" Finished ")
                print("W: ", w_, " b: ", b_)
                plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
                plt.grid(True)
                plt.show()
            else:
                loss_list = []
                for epoch in range(training_epochs):
                    for (x, y) in zip(train_x, train_y):
                        sess.run(optimizer, feed_dict={X: x, Y: y})
    
                    if epoch % display_step == 0:
                        loss = sess.run(cost, feed_dict={X: x, Y: y})
                        loss_list.append(loss)
                        print('Iter: ', epoch, ' Loss: ', loss)
    
                w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
    
                saver.save(sess, "model/linear")
    
                print(" Finished ")
                print("W: ", w_, " b: ", b_, " loss: ", loss)
                plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
                plt.grid(True)
                plt.show()

    3.查看模型的内容

    from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
    modeldir = 'model/'
    print_tensors_in_checkpoint_file(modeldir + 'linear.cpkt', None, True)

    在上述使用saver的代码中,我们还可以将参数放入Saver中实现指定存储参数的功能,可以指定存储变量名字和变量的对应关系,如下形式:

    saver = tf.train.Saver({'weight_':w, 'bias_':b})
    # saver = tf.train.Saver([w, b])

  • 相关阅读:
    数据库分库分表(sharding)系列(五) 一种支持自由规划无须数据迁移和修改路由代码的Sharding扩容方案
    数据库分库分表(sharding)系列(三) 关于使用框架还是自主开发以及sharding实现层面的考量
    docker的入门简介
    nginx方向代理详解及配置
    nginx配置文件详解
    nginx安装
    iptables防火墙
    服务器加载过程
    服务器
    操作系统
  • 原文地址:https://www.cnblogs.com/baby-lily/p/10924667.html
Copyright © 2011-2022 走看看