zoukankan      html  css  js  c++  java
  • tensorflow 使用 4 非线性回归

    # 输入一个 x 会计算出 y 值    y 是预测值,如果与 真的 y 值(y_data)接近就成功了 
    
    import tensorflow as tf
    import numpy as np
    # py 的画图工具
    import matplotlib.pyplot as plt
    
    # 用 numpy 生成个 200 个属性点  从 -0.5 到 0.5 间平均生成 200 个点
    #x_data = np.linspace(-0.5, 0.5, 200)      # 这只是生成了一维的数组
    # 用下边这句可以生成二维数组
    x_data = np.linspace(-0.5, 0.5, 200)[:, np.newaxis]
    
    # 生成随机值,和 x_data 的形状是一样的 ( 噪点 )
    noise = np.random.normal(0, 0.02, x_data.shape)
    
    # x_data 的平方+随机数
    y_data = np.square( x_data ) + noise
    
    
    # 定义二个占位符
    x = tf.placeholder( tf.float32, [None, 1] )  # [None, 1] 行不定,列只有一列
    y = tf.placeholder( tf.float32, [None, 1] )
    
        
    # 构建神经网络中间层                    一行十列
    Weights_L1 = tf.Variable( tf.random_normal([1, 10]))
    biases_L1 = tf.Variable( tf.zeros([1, 10]) )
    
    # 求出信号的总和          矩阵相乘,
    Wx_plus_b_L1 = tf.matmul(x, Weights_L1) + biases_L1
    # 中间层的输出
    L1 = tf.nn.tanh( Wx_plus_b_L1 )
    
    # 输出层                                    十行一列
    Weights_L2 = tf.Variable( tf.random.normal([10, 1]))
    biases_L2 = tf.Variable( tf.zeros([1, 1]) )
    # 求出信号的总和          矩阵相乘,
    Wx_plus_b_L2 = tf.matmul(L1, Weights_L2) + biases_L2
    # 得出最后的预测结果
    pred = tf.nn.tanh( Wx_plus_b_L2 )
    
    
    # 二次代价函数
    loss = tf.reduce_mean( tf.square(y - pred) )
    
    
    # 梯度下降法的优化器                           最小化代价函数
    train = tf.train.GradientDescentOptimizer( 0.2 ).minimize( loss )
    
    
    with tf.Session() as sess:
      # 初始化变量
      sess.run( tf.global_variables_initializer() )
      # 训练 2000 次
      for _ in range( 2000 ):
        sess.run( train, feed_dict={x:x_data, y:y_data} )
    
    
      # 得到预测值
      value = sess.run( pred, feed_dict={x:x_data} )
      # 用画图形式展现
      
      plt.figure()
      plt.scatter(x_data, y_data)
      plt.plot(x_data, value, 'r-', lw=5)
      plt.show()
      
    

      

  • 相关阅读:
    base64和Blob的相互转换
    限制文件上传的大小和尺寸
    git将本地项目提交到github
    vue-cli3创建项目时报错
    运行项目是node-sass报错的解决方法
    classList的使用
    将数组扁平化并去除其中重复数据,最终得到一个升序且不重复的数组
    移动端的图片放大
    js获取url中的参数
    HTML5-canvas
  • 原文地址:https://www.cnblogs.com/gdwz922/p/10637753.html
Copyright © 2011-2022 走看看