zoukankan      html  css  js  c++  java
  • 寒假学习进度12: 线性回归tensorflow2.0实现

    参考博客:https://blog.csdn.net/weixin_45665788/article/details/104919669

    import  matplotlib.pyplot as plt
    import  numpy as np
    import tensorflow as tf
    # 载入随机种子
    np.random.seed(5)
    #生成100个等差序列,每个值在-1 - 1 之间
    x_data = np.linspace(-1,1,1000)
    #y = 2x + 1 + 噪声,噪声的维度和x_Data一致
    y_data = 2 * x_data +1.0 +np.random.randn(*x_data.shape) * 0.4 #*表示把元组拆分为一个个单独的实参
    plt.scatter(x_data,y_data)
    plt.plot(x_data,2*x_data+1,color = 'red' ,linewidth = 3)
    
    #定义模型函数以及线性函数的斜率和截距
    def model(x,w,b):
        return tf.multiply(x,w)+b
    
    #设置损失函数,这里使用均方差作为损失函数
    def loss_fun(x,y,w,b):
        err = model(x,w,b)-y
        squared_err = tf.square(err)
        return tf.reduce_mean(squared_err)
    
    #返回梯度向量
    def grad(x,y,w,b):
        with tf.GradientTape() as tape:
            loss_ = loss_fun(x,y,w,b)
        return tape.gradient(loss_,[w,b])
    
    if __name__ == '__main__':
        #因为模型比较简单,因此超参的迭代次数设置的比较小
        # 构建线性函数的斜率和截距
        w = tf.Variable(np.random.randn(), tf.float32)
        b = tf.Variable(0.0, tf.float32)
        # 设置迭代次数和学习率
        train_epochs = 10
        learning_rate = 0.01
        loss = []
        count = 0
        display_count = 10  # 控制显示粒度的参数,每训练10个样本输出一次损失值
    
        # 开始训练,轮数为epoch,采用SGD随机梯度下降优化方法
        for epoch in range(train_epochs):
            for xs, ys in zip(x_data, y_data):
                # 计算损失,并保存本次损失计算结果
                loss_ = loss_fun(xs, ys, w, b)
                loss.append(loss_)
                # 计算当前[w,b]的梯度
                delta_w, delta_b = grad(xs, ys, w, b)
                change_w = delta_w * learning_rate
                change_b = delta_b * learning_rate
                w.assign_sub(change_w)
                b.assign_sub(change_b)
                # 训练步数加1
                count = count + 1
                if count % display_count == 0:
                    print('train epoch : ', '%02d' % (epoch + 1), 'step:%03d' % (count), 'loss= ', '{:.9f}'.format(loss_))
            # 完成一轮训练后,画图
            plt.plot(x_data, w.numpy() * x_data + b.numpy())
            plt.show()
    

      

  • 相关阅读:
    Leetcode 811. Subdomain Visit Count
    Leetcode 70. Climbing Stairs
    Leetcode 509. Fibonacci Number
    Leetcode 771. Jewels and Stones
    Leetcode 217. Contains Duplicate
    MYSQL安装第三步报错
    .net 开发WEB程序
    JDK版本问题
    打开ECLIPSE 报failed to load the jni shared library
    ANSI_NULLS SQL语句
  • 原文地址:https://www.cnblogs.com/yangqqq/p/14459774.html
Copyright © 2011-2022 走看看