  • TensorFlow (3): Implementing Linear Regression with an L2 Loss Function in TensorFlow
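
For reference, the script fits the model $y = Ax + b$ by minimizing the mean squared (L2) error over each mini-batch of $n = 25$ points, using plain gradient descent with learning rate $\eta = 0.05$. The gradients are written out by hand here only to show what `GradientDescentOptimizer.minimize` computes automatically; this is a standard derivation, not something stated in the original post.

$$
\hat y_i = A x_i + b,\qquad
L(A,b) = \frac{1}{n}\sum_{i=1}^{n}\bigl(y_i - (A x_i + b)\bigr)^2
$$

$$
\frac{\partial L}{\partial A} = -\frac{2}{n}\sum_{i=1}^{n} x_i\bigl(y_i - (A x_i + b)\bigr),\qquad
\frac{\partial L}{\partial b} = -\frac{2}{n}\sum_{i=1}^{n}\bigl(y_i - (A x_i + b)\bigr)
$$

$$
A \leftarrow A - \eta\,\frac{\partial L}{\partial A},\qquad
b \leftarrow b - \eta\,\frac{\partial L}{\partial b}
$$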

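    # The script uses the TensorFlow 1.x graph API (Session, placeholder); compat.v1 keeps it runnable under TensorFlow 2.x
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()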
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn import datasets
    sess=tf.Session()
    # Load the iris dataset
    iris=datasets.load_iris()
    # x: petal width (column 3); y: sepal length (column 0)
    x_vals=np.array([x[3] for x in iris.data])
    y_vals=np.array([x[0] for x in iris.data])
    
    learning_rate=0.05
    batch_size=25
    
    x_data=tf.placeholder(shape=[None,1],dtype=tf.float32)
    y_data=tf.placeholder(shape=[None,1],dtype=tf.float32)
    
    A=tf.Variable(tf.random_normal(shape=[1,1]))
    b=tf.Variable(tf.random_normal(shape=[1,1]))
    
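    # Linear model y = Ax + b: matmul(x_data, A) has shape (None, 1), and adding b keeps shape (None, 1)
    model_out=tf.add(tf.matmul(x_data,A),b)
    # L2 loss: mean of the squared residuals (mean squared error)
    loss=tf.reduce_mean(tf.square(y_data-model_out))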
    
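    # Gradient descent optimizer
    my_opt=tf.train.GradientDescentOptimizer(learning_rate)
    train_step=my_opt.minimize(loss)
    
    # Initialize variables after building the optimizer, so optimizer variables (e.g. Momentum/Adam slots) would be covered too
    init=tf.global_variables_initializer()
    sess.run(init)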
    
    # Training loop: 100 steps of mini-batch gradient descent
    loss_rec=[]
    for i in range(100):
        rand_index=np.random.choice(len(x_vals),size=batch_size)
        # Reshape each batch to shape (batch_size, 1) to match the placeholders
        rand_x=np.transpose([ x_vals[rand_index] ])
        rand_y=np.transpose([ y_vals[rand_index] ])
    
        # Run one training step, then evaluate the loss on the same batch
        sess.run(train_step,feed_dict={x_data:rand_x,y_data:rand_y})
        temp_loss =sess.run(loss,feed_dict={x_data:rand_x,y_data:rand_y})
    
        # Record the loss
        loss_rec.append(temp_loss)
        # Print progress every 25 steps
        if (i+1)%25==0:
            print('Step: %d A=%s b=%s'%(i+1,str(sess.run(A)),str(sess.run(b))))
            print('Loss: %s'% str(temp_loss))
    # Extract the fitted coefficients as scalars (A and b have shape [1, 1])
    [[slope]]=sess.run(A)
    print(slope)
    [[intercept]]=sess.run(b)
    best_fit=[]
    for i in x_vals:
        best_fit.append(slope*i+intercept)
    # x_vals is a 1-D array of shape (150,)
    plt.plot(x_vals,y_vals,'o',label='Data')
    plt.plot(x_vals,best_fit,'r-',label='Best fit line',linewidth=3)
    plt.legend(loc='upper left')
    
    plt.xlabel('Petal Width (cm)')
    plt.ylabel('Sepal Length (cm)')
    plt.show()
    # Plot the L2 loss recorded at each generation (training step)
    plt.plot(loss_rec,'k-',label='Loss')
    plt.title('L2 Loss per Generation')
    plt.xlabel('Generation')
    plt.ylabel('L2 Loss')
    plt.show()
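
To make explicit what `GradientDescentOptimizer(learning_rate).minimize(loss)` does at each step, here is a minimal NumPy-only sketch of the same mini-batch gradient descent on the L2 loss, with the gradients written out by hand. The helper name `fit_linear_sgd`, its `seed` parameter, and the synthetic data (the line y = 2x + 1 plus noise, not the iris data) are hypothetical choices for illustration. `np.polyfit(x, y, 1)` serves as the closed-form least-squares reference that the iterative fit should approach.

```python
import numpy as np


def fit_linear_sgd(x, y, learning_rate=0.05, batch_size=25, steps=100, seed=0):
    """Fit y ~ A*x + b by mini-batch gradient descent on the mean squared (L2) loss.

    Hypothetical helper mirroring the TensorFlow script: random initial A and b,
    batches sampled with replacement, one gradient step per batch.
    """
    rng = np.random.default_rng(seed)
    A, b = rng.normal(), rng.normal()  # random start, like tf.random_normal
    losses = []
    for _ in range(steps):
        idx = rng.choice(len(x), size=batch_size)  # indices drawn with replacement
        xb, yb = x[idx], y[idx]
        resid = yb - (A * xb + b)
        # Gradients of mean((y - (A*x + b))**2) with respect to A and b
        grad_A = -2.0 * np.mean(xb * resid)
        grad_b = -2.0 * np.mean(resid)
        A -= learning_rate * grad_A
        b -= learning_rate * grad_b
        # Like the script, record the loss on the same batch after the update
        losses.append(float(np.mean((yb - (A * xb + b)) ** 2)))
    return A, b, losses


# Synthetic data (hypothetical): y = 2x + 1 plus Gaussian noise
rng = np.random.default_rng(42)
x = rng.uniform(0.0, 2.5, size=150)
y = 2.0 * x + 1.0 + rng.normal(scale=0.1, size=150)

A, b, losses = fit_linear_sgd(x, y, steps=2000)
slope_ls, intercept_ls = np.polyfit(x, y, 1)  # closed-form least-squares fit
print(f"SGD:     A={A:.3f}  b={b:.3f}")
print(f"polyfit: A={slope_ls:.3f}  b={intercept_ls:.3f}")
```

On the iris arrays from the script, the corresponding call would be `fit_linear_sgd(x_vals, y_vals)`, which uses the same learning rate (0.05), batch size (25), and 100 steps as the TensorFlow version. Because A and b start from random values, 100 steps may stop short of the `np.polyfit` optimum; comparing the two is a quick way to check whether training ran long enough.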

  • Original post: https://www.cnblogs.com/x0216u/p/9170695.html