zoukankan      html  css  js  c++  java
  • Tensorflow学习教程非线性回归

    自己搭建神经网络求解非线性回归系数

    代码

    复制代码
    #coding:utf-8
    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    
    #使用numpy 生成200个随机点
    x_data = np.linspace(-0.5,0.5,200)[:,np.newaxis] #x_data:200行1列 数值在-0.5到0.5之间
    noise = np.random.normal(0,0.02,x_data.shape)#noise :200行1列
    y_data = np.square(x_data) + noise #y_data 200行1列
    
    #定义两个placeholder
    x = tf.placeholder(tf.float32, [None,1]) #x:任意行 1列
    y = tf.placeholder(tf.float32, [None,1]) #y:任意行 1列
    
    
    #输入的是一个数 输出的也是一个数 因此输入层和输出层都是一个神经元
    #定义一个神经网络中间层 可以是任意个神经元 例如10个
    #定义神经网络中间层
    Weights_L1 = tf.Variable(tf.random_normal([1,10])) #1行10列
    biase_L1 = tf.Variable(tf.zeros([1,10])) # 1,10
    Wx_plus_b_L1 = tf.matmul(x, Weights_L1)+biase_L1 #x被正式赋值之后是200行1列 tf.matmul(x, Weights_L1)结果是200行10列 而biase_L1是1行10列 那么这俩怎么相加呢  在Python里面 200行10列的A向量+1行10列的B向量, 相当于给A向量每行加了B向量
    L1 = tf.nn.tanh(Wx_plus_b_L1) #使用双曲正切函数作为激活函数 
    
    #定义神经网络输出层
    Weights_L2 = tf.Variable(tf.random_normal([10,1]))
    biase_L2 = tf.Variable(tf.zeros([1,1]))
    Wx_plus_b_L2 = tf.matmul(L1,Weights_L2) + biase_L2
    prediction = tf.nn.tanh(Wx_plus_b_L2)
    
    #二次代价函数
    loss = tf.reduce_mean(tf.square(y - prediction))
    #使用梯度下降法
    train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    with tf.Session() as sess:
        #变量初始化
        sess.run(tf.global_variables_initializer())
        for _ in range(2000):
            sess.run(train_step, feed_dict={x:x_data,y:y_data})
        #迭代2000次之后所有的权重值都求出来了
        #获得预测值
        prediction_value = sess.run(prediction,feed_dict={x:x_data})
        #画图
        plt.figure()
        plt.scatter(x_data,y_data) #画出散点图
        plt.plot(x_data,prediction_value,'r-',lw=5) #画出折线图
        plt.show()
    复制代码
     

    结果

     
    分类: 深度学习
  • 相关阅读:
    [译]CasperJS,基于PhantomJS的工具包
    [译]JavaScript:typeof的用途
    [译]JavaScript写的一个quine程序
    [译]Ruby中的解构赋值
    [译]DOM:元素ID就是全局变量
    [译]ECMAScript 6中的集合类型,第一部分:Set
    [译]JavaScript:Array.prototype和[]的性能差异
    [译]Web Inspector开始支持CSS区域
    [译]JavaScript:反科里化"this"
    [译]JavaScript:用什么来缩进
  • 原文地址:https://www.cnblogs.com/shuimuqingyang/p/9960772.html
Copyright © 2011-2022 走看看