zoukankan      html  css  js  c++  java
  • TensorFlow实现简单线性回归示例代码

    # -*- coding: utf-8 -*-
    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    
    def real_func():
        """Placeholder with no behavior; always returns ``None``."""
        return None
    
    
    def emperor(log_dir=None):
        """Fit a line to noisy synthetic data with TensorFlow 1.x gradient descent.

        Draws 1000 points scattered around ``y = 0.1 * x + 0.3``, then trains a
        weight ``W`` and bias ``b`` for 20 gradient-descent steps (learning rate
        0.5) on the mean squared error, printing ``W``, ``b`` and the loss after
        every step. Afterwards the graph is written out for TensorBoard and the
        data is plotted together with the fitted line.

        Args:
            log_dir: Directory that receives the TensorBoard event file. When
                ``None``, the hard-coded Windows directory from the original
                script is used.
        """
        if log_dir is None:
            # NOTE(review): the original literal had lost its backslashes and was
            # split across lines (a syntax error); this is the best
            # reconstruction of the intended path -- confirm before relying on it.
            log_dir = (r'C:\Users\Administrator\Desktop\meatwice\meatwice\1new'
                       r'\Cognition\reinforcement_learning\new_test_tensorflow/tmp')

        num_points = 1000
        x_data = []
        y_data = []
        for _ in range(num_points):
            x1 = np.random.normal(0.0, 0.55)
            # Ground-truth line plus a little Gaussian noise.
            y1 = x1 * 0.1 + 0.3 + np.random.normal(0.0, 0.03)
            x_data.append(x1)
            y_data.append(y1)

        W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
        b = tf.Variable(tf.zeros([1]), name='b')
        y = W * x_data + b

        loss = tf.reduce_mean(tf.square(y - y_data), name='loss')
        optimizer = tf.train.GradientDescentOptimizer(0.5)
        train = optimizer.minimize(loss, name='train')
        init = tf.global_variables_initializer()

        # The context manager releases the session's resources even if a step fails.
        with tf.Session() as sess:
            sess.run(init)

            for _ in range(20):
                sess.run(train)
                w_value, b_value, loss_value = sess.run([W, b, loss])
                print('W=', w_value, 'b=', b_value, 'loss=', loss_value)

            writer = tf.summary.FileWriter(log_dir, sess.graph)
            # Close explicitly so the event file is flushed to disk.
            writer.close()

            # Fetch the final parameters while the session is still open.
            w_value, b_value = sess.run([W, b])

        plt.scatter(x_data, y_data, c='r')
        plt.plot(x_data, w_value * x_data + b_value)
        plt.show()
    
    
    
    # Run the regression demo only when this file is executed as a script.
    if '__main__' == __name__:
        emperor()

    运行结果:

  • 相关阅读:
    BZOJ 1064 NOI2008 假面舞会
    p1186反素数(模板题)
    p1140【飞船控制站】
    p1164【立方体求和】
    p1103【base64编码】
    长沙集训(day不知道)伪总结(怕不是是一篇心得??)
    长沙集训day12
    长沙集训day10
    长沙集训day9
    长沙集训day8
  • 原文地址:https://www.cnblogs.com/zrmw/p/11572329.html
Copyright © 2011-2022 走看看