zoukankan      html  css  js  c++  java
  • TensorFlow(三) 用TensorFlow实现L2正则损失函数线性回归算法

    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn import  datasets
    sess=tf.Session()
    #加载鸢尾花集
    iris=datasets.load_iris()
    #宽度 长度
    x_vals=np.array([x[3] for x in iris.data])
    y_vals=np.array([x[0] for x in iris.data])
    
    learning_rate=0.05
    batch_size=25
    
    x_data=tf.placeholder(shape=[None,1],dtype=tf.float32)
    y_data=tf.placeholder(shape=[None,1],dtype=tf.float32)
    
    A=tf.Variable(tf.random_normal(shape=[1,1]))
    b=tf.Variable(tf.random_normal(shape=[1,1]))
    
    #增加线性模型y=Ax+b  x*a==>shape(None,1)+b==>shape(NOne,1)
    model_out=tf.add(tf.matmul(x_data,A),b)
    #声明L2损失函数
    loss=tf.reduce_mean(tf.square(y_data-model_out))
    
    #初始化变量
    init=tf.global_variables_initializer()
    sess.run(init)
    
    #梯度下降
    my_opt=tf.train.GradientDescentOptimizer(learning_rate)
    train_step=my_opt.minimize(loss)
    
    #循环迭代
    loss_rec=[]
    for i in range(100):
        rand_index=np.random.choice(len(x_vals),size=batch_size)
        #shape(None,1)
        rand_x=np.transpose([ x_vals[rand_index] ])
        rand_y=np.transpose([ y_vals[rand_index] ])
    
        #运行
        sess.run(train_step,feed_dict={x_data:rand_x,y_data:rand_y})
        temp_loss =sess.run(loss,feed_dict={x_data:rand_x,y_data:rand_y})
    
        #添加记录
        loss_rec.append(temp_loss)
        #打印
        if (i+1)%25==0:
            print('Step: %d A=%s b=%s'%(i,str(sess.run(A)),str(sess.run(b))))
            print('Loss:%s'% str(temp_loss))
    #抽取系数
    [slope]=sess.run(A)
    print(slope)
    [intercept]=sess.run(b)
    best_fit=[]
    for i in x_vals:
        best_fit.append(slope*i+intercept)
    #x_vals shape(None,1)
    plt.plot(x_vals,y_vals,'o',label='Data')
    plt.plot(x_vals,best_fit,'r-',label='Best fit line',linewidth=3)
    plt.legend(loc='upper left')
    
    plt.xlabel('Pedal Width')
    plt.ylabel('Pedal Length')
    plt.show()
    #L2
    plt.plot(loss_rec,'k-',label='Loss')
    plt.title('L2 loss per Generation')
    plt.xlabel('Generation')
    plt.ylabel('L2 loss ')
    plt.show()

  • 相关阅读:
    SDN期末作业验收
    SDN第五次上机作业
    SDN第四次作业
    SDN第四次上机作业
    SDN第三次上机
    SDN第三次作业
    第二次SDN上机作业
    SDN第二次作业
    SDN第一次上机作业
    SDN第一次作业
  • 原文地址:https://www.cnblogs.com/x0216u/p/9170695.html
Copyright © 2011-2022 走看看