zoukankan      html  css  js  c++  java
  • tensorflow实现线性回归、以及模型保存与加载

    内容:包含tensorflow变量作用域、tensorboard收集、模型保存与加载、自定义命令行参数

    1、知识点

    """
    1、训练过程:
            1、准备好特征和目标值
            2、建立模型,随机初始化权重和偏置; 模型的参数必须要使用变量
            3、求损失函数,误差为均方误差
            4、梯度下降去优化损失过程,指定学习率
       
    2、Tensorflow运算API:
            1、矩阵运算:tf.matmul(x,w)
            2、平方:tf.square(error)
            3、均值:tf.reduce_mean(error) 
            4、梯度下降API: tf.train.tf.train.GradientDescentOptimizer(learning_rate)
                    learning_rate:学习率
                    minimize(lose):优化最小损失
                    return:梯度下降op
    3、注意项:
            1、tf.Variable()中的trainable表示为变量在训练过程可变
            2、学习率设置很大时,可能会出现权重和偏置为NAN,这种现象表现叫梯度爆炸
                    解决方法:1、重新设计网络 2、调整学习率 3、使用梯度截断 4、使用激活函数
    
    4、变量作用域:主要用于tensorboard查看,同时使代码更加清晰  with tf.variable_scope("data"):
    
    5、添加权重、参数、损失值等在tensoroard观察的情况:
            1、收集tensor变量 tf.summary.scalar('losses', loss)、tf.summary.histogram('weight',weight)
            2、合并变量并写入事件文件:merged = tf.summary.merge_all() 
            3、运行合并的tensor:summary = sess.run(merged)、fileWriter.add_summary(summary,i)
    
    6、模型保存与加载: tf.train.Saver(var_list=None,max_to_keep=5)
            var_list:指定将要保存和还原的变量。它可以作为一个dict或一个列表传递.
            max_to_keep:指示要保留的最近检查点文件的最大数量。
            创建新文件时,会删除较旧的文件。如果无或0,则保留所有
            检查点文件。默认为5(即保留最新的5个检查点文件。
            a)例如:saver.save(sess, '/tmp/ckpt/test/model')
                saver.restore(sess, '/tmp/ckpt/test/model')
                保存文件格式:checkpoint文件
                
            b)模型加载:
                if os.path.exists('./ckpt/checkpoint'):
                    saver.restore(sess,'./ckpt/model')
                    
    7、自定义命令行参数:
        1、首先定义有哪些参数需要在运行时候指定
        2、程序当中获取定义命令行参数
        3、运行 python *.py --max_step=500 --model_dir='./ckpt/model' 
        本例执行命令:python tensorflow实现线性回归.py --max_step=50 --model_dir="./ckpt/model"
    
    """

    2、代码

    # coding = utf-8
    import tensorflow as tf
    import  os
    
    #自定义命令行参数
    tf.app.flags.DEFINE_integer("max_step",100,"模型训练的步数")
    tf.app.flags.DEFINE_string("model_dir"," ","模型文件加载路径")
    #定义获取命令行参数名字
    FLAGS = tf.app.flags.FLAGS
    def myLinear():
        """
        自实现一个线性回归预测
        :return:
        """
        #定义作用域
        with tf.variable_scope("data"):
            #1、准备数据,特征值
            x = tf.random_normal([100,1],mean=1.75,stddev=0.5)
            #目标值。矩阵相乘,必须是二维的
            y_true = tf.matmul(x,[[0.7]])+0.8
    
        with tf.variable_scope("model"):
            #2、建立线性模型 y = wx+b ,随机给定w和b的值,必须定义成变量
            weight = tf.Variable(tf.random_normal([1,1],mean=0.0,stddev=1.0,name="w"))
            bias = tf.Variable(0.0,name="b")
            y_predict = tf.matmul(x,weight)+bias
    
        with tf.variable_scope("loss"):
            #3、建立损失函数,均方误差
            loss = tf.reduce_mean(tf.square(y_true-y_predict))
        with tf.variable_scope("optimizer"):
            #4、梯度下降优化损失
            train_op = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(loss)
    
    ##############模型保存################
        with tf.variable_scope("save_model"):
            saver = tf.train.Saver();
    
        # 初始化变量
        init_op = tf.global_variables_initializer()
    ####################收集变量#########################
        # 收集tensor变量
        tf.summary.scalar('losses', loss)
        tf.summary.histogram('weight',weight)
    
        #合并变量并写入事件文件
        merged = tf.summary.merge_all()
        #通过会话运行程序
        with tf.Session() as sess:
            #必须要运行初始化变量
            sess.run(init_op)
    
            #打印随机最先初始化的权重和偏置
            print("随机初始化的参数权重为:%f,偏置为:%f" % (weight.eval(),bias.eval()))
            # 建立事件文件
            fileWriter = tf.summary.FileWriter("./tmp", graph=sess.graph)
    
            ###########加载模型,覆盖之前的参数##############
            if os.path.exists('./ckpt/checkpoint'):
                #saver.restore(sess,'./ckpt/model')
                saver.restore(sess, FLAGS.model_dir)
    
    
            #循环优化
            for i in range(FLAGS.max_step):
                #运行优化
                sess.run(train_op)
                #运行合并的tensor
                summary = sess.run(merged)
                fileWriter.add_summary(summary,i)
                print("第%d次优化参数权重为:%f,偏置为:%f" % (i,weight.eval(), bias.eval()))
            ################模型保存##############
                # if i%1000==0:
                #     #saver.save(sess,'./ckpt/model')
                #     saver.save(sess,FLAGS.model_dir)
            saver.save(sess, FLAGS.model_dir)
        return None
    
    
    if __name__ == '__main__':
        myLinear()
  • 相关阅读:
    conda环境配置以及pyinstaller报错配置
    软件测试的艺术--读书笔记
    flex布局相关
    移动端特殊样式
    css3中的2D转换
    logo seo优化
    html5 简单的新特性
    css中溢出文字省略号方式
    css用户界面样式
    精灵图与字体图标相关
  • 原文地址:https://www.cnblogs.com/ywjfx/p/10910005.html
Copyright © 2011-2022 走看看