zoukankan      html  css  js  c++  java
  • TensorFlow笔记-模型的保存,恢复,实现线性回归

    模型的保存

    tf.train.Saver(var_list=None,max_to_keep=5)

    •var_list:指定将要保存和还原的变量。它可以作为一个

    dict或一个列表传递.

    •max_to_keep:指示要保留的最近检查点文件的最大数量。

    创建新文件时,会删除较旧的文件。如果无或0,则保留所有

    检查点文件。默认为5(即保留最新的5个检查点文件。)

    saver = tf.train.Saver()
    saver.save(sess, "")

    模型的恢复

    恢复模型的方法是restore(sess, save_path),save_path是以前保存参数的路径,我们可以使用tf.train.latest_checkpoint来获取最近的检查点文件(也恶意直接写文件目录)

    if os.path.exists("tmp/ckpt/checkpoint"):
                saver.restore(sess,"")
                print("恢复模型")

    自定义命令行参数

    import tensorflow as tf
    
    FLAGS = tf.app.flags.FLAGS
    
    tf.app.flags.DEFINE_string('data_dir', '/tmp/tensorflow/mnist/input_data',
                               """数据集目录""")
    tf.app.flags.DEFINE_integer('max_steps', 2000,
                                """训练次数""")
    tf.app.flags.DEFINE_string('summary_dir', '/tmp/summary/mnist/convtrain',
                               """事件文件目录""")
    
    def main(argv):
        print(FLAGS.data_dir)
        print(FLAGS.max_steps)
        print(FLAGS.summary_dir)
        print(argv)
    
    
    if __name__=="__main__":
        tf.app.run()

    线性回归

    准备数据

    with tf.variable_scope("data"):
        # 1、准备数据,x 特征值 [100, 1]   y 目标值[100]
        x = tf.random_normal([100, 1], mean=1.75, stddev=0.5, name="x_data")
        # 矩阵相乘必须是二维的
         y_true = tf.matmul(x, [[0.7]]) + 0.8

    构建模型

    with tf.variable_scope("model"):
        # 2、建立线性回归模型 1个特征,1个权重, 一个偏置 y = x w + b
        # 随机给一个权重和偏置的值,让他去计算损失,然后再当前状态下优化
        # 用变量定义才能优化
        weight = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0), name="w")
        bias = tf.Variable(0.0, name="b")
        y_predict = tf.matmul(x, weight) + bias

    构造损失函数

    with tf.variable_scope("loss"):
       # 3、建立损失函数,均方误差
       loss = tf.reduce_mean(tf.square(y_true - y_predict))

    利用梯度下降

    with tf.variable_scope("optimizer"):
           # 4、梯度下降优化损失 leaning_rate: 0 ~ 1, 2, 3,5, 7, 10
           train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

    源码

    import tensorflow as tf
    import os
    
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    # 在这里立flag
    tf.app.flags.DEFINE_integer("max_step",100,"模型训练的步数")
    tf.app.flags.DEFINE_string("model_dir","tmp/summary/test","模型文件的加载路径")
    
    FLAGS = tf.app.flags.FLAGS
    def myregression():
        with tf.variable_scope("data"):
            x = tf.random_normal([100, 1], mean=1.75, stddev=0.5)
            y_true = tf.matmul(x, [[0.7]]) + 0.8
        with tf.variable_scope("model"):
            # 权重 trainable 指定权重是否随着session改变
            weight = tf.Variable(tf.random_normal([int(x.shape[1]), 1], mean=0, stddev=1), name="w")
            # 偏置项
            bias = tf.Variable(0.0, name='b')
            # 构造y函数
            y_predict = tf.matmul(x, weight) + bias
        with tf.variable_scope("loss"):
            # 定义损失函数
            loss = tf.reduce_mean(tf.square(y_true - y_predict))
        with tf.variable_scope("optimizer"):
            # 使用梯度下降进行求解
            train_op = tf.train.GradientDescentOptimizer(0.1).minimize((loss))
        # 1.收集tensor
        tf.summary.scalar("losses", loss)
        tf.summary.histogram("weights", weight)
        # 2.定义合并tensor的op
        merged = tf.summary.merge_all()
        # 定义一个保存模型的op
        saver = tf.train.Saver()
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            # import matplotlib.pyplot as plt
            # plt.scatter(x.eval(), y_true.eval())
            # plt.show()
            print("初始化的权重:%f,偏置项:%f" % (weight.eval(), bias.eval()))
            # 建立事件文件
            filewriter = tf.summary.FileWriter('./tmp/summary/test/', graph=sess.graph)
            # 加载模型
            if os.path.exists("tmp/ckpt/checkpoint"):
                saver.restore(sess,FLAGS.model_dir)
                print("加载")
            n = 0
            while loss.eval() > 1e-6:
                n += 1
                if(n==FLAGS.max_step):
                    break
                sess.run(train_op)
                summary = sess.run(merged)
                filewriter.add_summary(summary, n)
                print("第%d次权重:%f,偏置项:%f" % (n, weight.eval(), bias.eval()))
            saver.save(sess, FLAGS.model_dir)
        return weight, bias
    
    
    myregression()
    # x_min,x_max = np.min(x.eval()),np.max(x.eval())
    # tx = np.arange(x_min,x_max,100)

  • 相关阅读:
    两种方式创建Maven项目【方式二】
    两种方式创建Maven项目【方式一】
    《Java程序设计》第二周学习记录(1)
    《Java程序设计》第一周学习记录(2)
    《Java程序设计》第一周学习记录(1)
    《Java程序设计》第一周学习总结
    Python isinstance
    笔记:Struts2 Action 非泛型集合元素类型转换
    笔记:Struts2 输入校验
    笔记:Struts2 国际化
  • 原文地址:https://www.cnblogs.com/TimVerion/p/11224498.html
Copyright © 2011-2022 走看看