zoukankan      html  css  js  c++  java
  • TensorFlow(四) 用TensorFlow实现弹性网络回归算法(多线性回归)

    弹性网络回归算法是综合lasso回归和岭回归的一种回归算法,通过在损失函数中增加L1正则和L2正则项,进而控制单个系数对结果的影响

    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn import  datasets
    sess=tf.Session()
    #加载鸢尾花集
    iris=datasets.load_iris()
    #花瓣长度,花瓣宽度,花萼宽度 预测 花萼长度
    x_vals=np.array([ [x[1],x[2],x[3]] for x in iris.data])
    y_vals=np.array([y[0] for y in iris.data])
    
    learning_rate=0.001
    batch_size=50
    
    x_data=tf.placeholder(shape=[None,3],dtype=tf.float32)
    y_target=tf.placeholder(shape=[None,1],dtype=tf.float32)
    
    A=tf.Variable(tf.random_normal(shape=[3,1]))
    b=tf.Variable(tf.random_normal(shape=[1,1]))
    
    #增加线性模型y=Ax+b  x*a==>shape(None,1)+b==>shape(NOne,1)
    model_out=tf.add(tf.matmul(x_data,A),b)
    #参数1,2
    elastic_p1=tf.constant(1.)
    elastic_p2=tf.constant(1.)
    
    
    #声明损失函数 包含斜率的L1正则和L2正则。
    #创建正则项
    l1_a_loss=tf.reduce_mean(tf.abs(A))
    l2_a_loss=tf.reduce_mean(tf.square(A))
    e1_term=tf.multiply(elastic_p1,l1_a_loss)
    e2_term=tf.multiply(elastic_p2,l2_a_loss)
    #这里A是不规则的shape即3,1的数组形式  对应的loss也扩展成数组形式
    loss=tf.expand_dims(tf.add(tf.add(tf.reduce_mean(tf.square(y_target-model_out)),e1_term),e2_term),0)
    
    #初始化变量
    init=tf.global_variables_initializer()
    sess.run(init)
    
    
    #梯度下降
    my_opt=tf.train.GradientDescentOptimizer(learning_rate)
    train_step=my_opt.minimize(loss)
    
    #循环迭代
    loss_rec=[]
    for i in range(1000):
        rand_index=np.random.choice(len(x_vals),size=batch_size)
        #shape(None,3)
        rand_x= x_vals[rand_index]
        rand_y= np.transpose([y_vals[rand_index]])
        #运行
        sess.run(train_step,feed_dict={x_data:rand_x,y_target:rand_y})
        temp_loss =sess.run(loss,feed_dict={x_data:rand_x,y_target:rand_y})
    
        #添加记录
        loss_rec.append(temp_loss)
        #打印
        if (i+1)%250==0:
            print('Step: %d A=%s b=%s'%(i,str(sess.run(A)),str(sess.run(b))))
            print('Loss:%s'% str(temp_loss[0]))
    
    #弹性网络回归迭代图形
    plt.plot(loss_rec,'k-',label='Loss')
    plt.title('Loss per Generation')
    plt.xlabel('Generation')
    plt.ylabel(' loss ')
    plt.show()

  • 相关阅读:
    策略梯度(Policy Gradient)
    无约束优化问题
    有约束优化问题
    计算机网络学习资料
    为什么要用等效基带信号?
    通信网实验—话务量分析
    无感数据埋点(自定义注解+aop+异步)
    排序算法
    位运算常见操作
    数据库与缓存一致性的几种实现方式
  • 原文地址:https://www.cnblogs.com/x0216u/p/9173106.html
Copyright © 2011-2022 走看看