zoukankan      html  css  js  c++  java
  • TensorFlow_Random_linear_modeul

    tensorflow_Linear_regression_demo

    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    
    #產生一個含有隨機的線性模型
    x_data = np.random.rand(100) #均勻隨機產生100個點
    noise = 0.1*np.random.randn(100) #隨機項
    y_data = x_data * 1 + 3 + noise  # y_data = 1 * x_data + 3  + noise
    
    #宣告 tensorflow 中的變數
    # y = m*x + b
    m = tf.Variable(0.0)
    b = tf.Variable(0.0)
    y = m*x_data + b
    
    #代價函數 : loss = mean((y-y_data)^2)
    #其中 tf.reduce_mean 計算 tensor中每一個 dimension 的平均值
    # tf.square 計算 tensor 中每一個元的平方
    loss = tf.reduce_mean(tf.square(y_data - y))
    
    #Gradient desent method  (learning rate = 0.1)
    gd = tf.train.GradientDescentOptimizer(0.3)
    
    #最小化 代價函數 (operator)
    train = gd.minimize(loss)
    
    #初始化變數 operator
    init = tf.global_variables_initializer()
    
    with tf.Session() as sess:
        sess.run(init)
        for step in range(50):
            sess.run(train)
            print("iter=", step, ", m=",sess.run(m), ", b=", sess.run(b))
        #將 train 後的結果存下來
        m = sess.run(m);
        b = sess.run(b);
        #繪製結果圖
        plt.figure()
        plt.scatter(x_data, y_data)
        plt.plot([0, 1], [b, m*1+b], '-r', lw=3)
        plt.show()
    


  • 相关阅读:
    jdk9 特性
    jdk8 特性
    Eclipse中Spring插件的安装
    C++避免程序运行完后窗口一闪而过的方法
    完全二叉树节点个数
    Shell 编写倒着的*三角形
    Drools源于规则引擎
    Spring Data MongoDB 三:基本文档查询(Query、BasicQuery
    docker环境搭建
    MyBatis根据数组、集合查询
  • 原文地址:https://www.cnblogs.com/hugeng007/p/9593280.html
Copyright © 2011-2022 走看看