zoukankan      html  css  js  c++  java
  • Tensorflow学习教程------非线性回归

    自己搭建神经网络求解非线性回归系数

    代码

    #coding:utf-8
    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    
    #使用numpy 生成200个随机点
    x_data = np.linspace(-0.5,0.5,200)[:,np.newaxis] #x_data:200行1列 数值在-0.5到0.5之间
    noise = np.random.normal(0,0.02,x_data.shape)#noise :200行1列
    y_data = np.square(x_data) + noise #y_data 200行1列
    
    #定义两个placeholder
    x = tf.placeholder(tf.float32, [None,1]) #x:任意行 1列
    y = tf.placeholder(tf.float32, [None,1]) #y:任意行 1列
    
    
    #输入的是一个数 输出的也是一个数 因此输入层和输出层都是一个神经元
    #定义一个神经网络中间层 可以是任意个神经元 例如10个
    #定义神经网络中间层
    Weights_L1 = tf.Variable(tf.random_normal([1,10])) #1行10列
    biase_L1 = tf.Variable(tf.zeros([1,10])) # 1,10
    Wx_plus_b_L1 = tf.matmul(x, Weights_L1)+biase_L1 #x被正式赋值之后是200行1列 tf.matmul(x, Weights_L1)结果是200行10列 而biase_L1是1行10列 那么这俩怎么相加呢  在Python里面 200行10列的A向量+1行10列的B向量, 相当于给A向量每行加了B向量
    L1 = tf.nn.tanh(Wx_plus_b_L1) #使用双曲正切函数作为激活函数 
    
    #定义神经网络输出层
    Weights_L2 = tf.Variable(tf.random_normal([10,1]))
    biase_L2 = tf.Variable(tf.zeros([1,1]))
    Wx_plus_b_L2 = tf.matmul(L1,Weights_L2) + biase_L2
    prediction = tf.nn.tanh(Wx_plus_b_L2)
    
    #二次代价函数
    loss = tf.reduce_mean(tf.square(y - prediction))
    #使用梯度下降法
    train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    with tf.Session() as sess:
        #变量初始化
        sess.run(tf.global_variables_initializer())
        for _ in range(2000):
            sess.run(train_step, feed_dict={x:x_data,y:y_data})
        #迭代2000次之后所有的权重值都求出来了
        #获得预测值
        prediction_value = sess.run(prediction,feed_dict={x:x_data})
        #画图
        plt.figure()
        plt.scatter(x_data,y_data) #画出散点图
        plt.plot(x_data,prediction_value,'r-',lw=5) #画出折线图
        plt.show()
     

    结果

  • 相关阅读:
    Spring cloud:熔断器-客户端降级
    Spring cloud:熔断器-服务端降级
    Spring cloud:服务调用-声明式客户端访问
    Spring cloud:服务调用-服务名访问
    Spring cloud:服务调用-IP访问
    Spring cloud:支付微服务-服务注册
    Spring cloud:支付微服务-支付
    Spring cloud:支付微服务-创建
    python读取数据库mysql报错
    自动化测试注意事项
  • 原文地址:https://www.cnblogs.com/cnugis/p/7635004.html
Copyright © 2011-2022 走看看