zoukankan      html  css  js  c++  java
  • tensorflow实现线性回归、以及模型保存与加载

    内容:包含tensorflow变量作用域、tensorboard收集、模型保存与加载、自定义命令行参数

    1、知识点

    """
    1、训练过程:
            1、准备好特征和目标值
            2、建立模型,随机初始化权重和偏置; 模型的参数必须要使用变量
            3、求损失函数,误差为均方误差
            4、梯度下降去优化损失过程,指定学习率
       
    2、Tensorflow运算API:
            1、矩阵运算:tf.matmul(x,w)
            2、平方:tf.square(error)
            3、均值:tf.reduce_mean(error) 
            4、梯度下降API: tf.train.tf.train.GradientDescentOptimizer(learning_rate)
                    learning_rate:学习率
                    minimize(lose):优化最小损失
                    return:梯度下降op
    3、注意项:
            1、tf.Variable()中的trainable表示为变量在训练过程可变
            2、学习率设置很大时,可能会出现权重和偏置为NAN,这种现象表现叫梯度爆炸
                    解决方法:1、重新设计网络 2、调整学习率 3、使用梯度截断 4、使用激活函数
    
    4、变量作用域:主要用于tensorboard查看,同时使代码更加清晰  with tf.variable_scope("data"):
    
    5、添加权重、参数、损失值等在tensoroard观察的情况:
            1、收集tensor变量 tf.summary.scalar('losses', loss)、tf.summary.histogram('weight',weight)
            2、合并变量并写入事件文件:merged = tf.summary.merge_all() 
            3、运行合并的tensor:summary = sess.run(merged)、fileWriter.add_summary(summary,i)
    
    6、模型保存与加载: tf.train.Saver(var_list=None,max_to_keep=5)
            var_list:指定将要保存和还原的变量。它可以作为一个dict或一个列表传递.
            max_to_keep:指示要保留的最近检查点文件的最大数量。
            创建新文件时,会删除较旧的文件。如果无或0,则保留所有
            检查点文件。默认为5(即保留最新的5个检查点文件。
            a)例如:saver.save(sess, '/tmp/ckpt/test/model')
                saver.restore(sess, '/tmp/ckpt/test/model')
                保存文件格式:checkpoint文件
                
            b)模型加载:
                if os.path.exists('./ckpt/checkpoint'):
                    saver.restore(sess,'./ckpt/model')
                    
    7、自定义命令行参数:
        1、首先定义有哪些参数需要在运行时候指定
        2、程序当中获取定义命令行参数
        3、运行 python *.py --max_step=500 --model_dir='./ckpt/model' 
        本例执行命令:python tensorflow实现线性回归.py --max_step=50 --model_dir="./ckpt/model"
    
    """

    2、代码

    # coding = utf-8
    import tensorflow as tf
    import  os
    
    #自定义命令行参数
    tf.app.flags.DEFINE_integer("max_step",100,"模型训练的步数")
    tf.app.flags.DEFINE_string("model_dir"," ","模型文件加载路径")
    #定义获取命令行参数名字
    FLAGS = tf.app.flags.FLAGS
    def myLinear():
        """
        自实现一个线性回归预测
        :return:
        """
        #定义作用域
        with tf.variable_scope("data"):
            #1、准备数据,特征值
            x = tf.random_normal([100,1],mean=1.75,stddev=0.5)
            #目标值。矩阵相乘,必须是二维的
            y_true = tf.matmul(x,[[0.7]])+0.8
    
        with tf.variable_scope("model"):
            #2、建立线性模型 y = wx+b ,随机给定w和b的值,必须定义成变量
            weight = tf.Variable(tf.random_normal([1,1],mean=0.0,stddev=1.0,name="w"))
            bias = tf.Variable(0.0,name="b")
            y_predict = tf.matmul(x,weight)+bias
    
        with tf.variable_scope("loss"):
            #3、建立损失函数,均方误差
            loss = tf.reduce_mean(tf.square(y_true-y_predict))
        with tf.variable_scope("optimizer"):
            #4、梯度下降优化损失
            train_op = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(loss)
    
    ##############模型保存################
        with tf.variable_scope("save_model"):
            saver = tf.train.Saver();
    
        # 初始化变量
        init_op = tf.global_variables_initializer()
    ####################收集变量#########################
        # 收集tensor变量
        tf.summary.scalar('losses', loss)
        tf.summary.histogram('weight',weight)
    
        #合并变量并写入事件文件
        merged = tf.summary.merge_all()
        #通过会话运行程序
        with tf.Session() as sess:
            #必须要运行初始化变量
            sess.run(init_op)
    
            #打印随机最先初始化的权重和偏置
            print("随机初始化的参数权重为:%f,偏置为:%f" % (weight.eval(),bias.eval()))
            # 建立事件文件
            fileWriter = tf.summary.FileWriter("./tmp", graph=sess.graph)
    
            ###########加载模型,覆盖之前的参数##############
            if os.path.exists('./ckpt/checkpoint'):
                #saver.restore(sess,'./ckpt/model')
                saver.restore(sess, FLAGS.model_dir)
    
    
            #循环优化
            for i in range(FLAGS.max_step):
                #运行优化
                sess.run(train_op)
                #运行合并的tensor
                summary = sess.run(merged)
                fileWriter.add_summary(summary,i)
                print("第%d次优化参数权重为:%f,偏置为:%f" % (i,weight.eval(), bias.eval()))
            ################模型保存##############
                # if i%1000==0:
                #     #saver.save(sess,'./ckpt/model')
                #     saver.save(sess,FLAGS.model_dir)
            saver.save(sess, FLAGS.model_dir)
        return None
    
    
    if __name__ == '__main__':
        myLinear()
  • 相关阅读:
    WampServer Mysql配置
    Java实现 蓝桥杯VIP 算法提高 陶陶摘苹果2
    Java实现 蓝桥杯VIP 算法提高 陶陶摘苹果2
    Java实现 蓝桥杯VIP 算法提高 陶陶摘苹果2
    Java实现 蓝桥杯VIP 算法提高 质因数2
    Java实现 蓝桥杯VIP 算法提高 质因数2
    Java实现 蓝桥杯VIP 算法提高 质因数2
    Java实现 蓝桥杯VIP 算法提高 质因数2
    Java实现 蓝桥杯VIP 算法提高 质因数2
    Java实现 蓝桥杯VIP 算法提高 前10名
  • 原文地址:https://www.cnblogs.com/ywjfx/p/10910005.html
Copyright © 2011-2022 走看看