zoukankan      html  css  js  c++  java
  • 非线性回归示例代码【多测师_王sir】

    一、非线性回归例子
    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
     
    #生成随机点
    x_data = np.linspace(-0.5,0.5,200)[:,np.newaxis] #返回均匀间隔的数字
    noise = np.random.normal(0,0.02,x_data.shape)
    y_data = np.square(x_data) + noise
     
    #定义两个placeholder
    x = tf.placeholder(tf.float32,[None,1])
    y = tf.placeholder(tf.float32,[None,1])
     
    #构建神经网络的中间层
    Weights_L1 = tf.Variable(tf.random_normal([1,10]))
    biases1 = tf.Variable(tf.zeros([1,10]))
    Wx_plus_b_L1 = tf.matmul(x,Weights_L1) + biases1  #注意multiply和matmul的区别:multiply矩阵维度必须相同,matmul矩阵相乘维度可以不同
    L1 = tf.nn.tanh(Wx_plus_b_L1)
     
    #构建神经网络的输出层
    Weights_L2 = tf.Variable(tf.random_normal([10,1]))
    biases2 = tf.Variable(tf.zeros([1,1]))
    Wx_plus_b_L2 = tf.matmul(L1,Weights_L2) + biases2  #注意multiply和matmul的区别:multiply矩阵维度必须相同,matmul矩阵相乘维度可以不同
    prediction = tf.nn.tanh(Wx_plus_b_L2)
     
    #二次代价函数
    loss = tf.reduce_mean(tf.square(y-prediction))
    #梯度下降法法训练
    train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
     
    with tf.Session() as sess:
        #变量初始化
        sess.run(tf.global_variables_initializer())
        #训练
        for _ in range(2000):
            sess.run(train_step,feed_dict={x:x_data,y:y_data})
        #预测
        prediction_value = sess.run(prediction,feed_dict={x:x_data})
        #画图
        plt.figure()
        plt.scatter(x_data,y_data)
        plt.plot(x_data,prediction_value,'r-',lw=5)#红色实线 宽度为5
    plt.show()
    
    
    函数的一些解释:
    1.注意multiply和matmul的区别:multiply矩阵维度必须相同,matmul矩阵相乘维度可以不同
    2.random.normal解释: 高斯分布的概率密度函数                                            
    numpy中numpy.random.normal(loc=0.0, scale=1.0, size=None)
    
    参数的意义为:
    loc:float
    概率分布的均值,对应着整个分布的中心center
    scale:float
    概率分布的标准差,对应于分布的宽度,scale越大越矮胖,scale越小,越瘦高
    size:int or tuple of ints
    输出的shape,默认为None,只输出一个值
    我们更经常会用到np.random.randn(size)所谓标准正太分布(μ=0, σ=1),对应于np.random.normal(loc=0, scale=1, size)
    二、MNIST手写体数据集分类简单版本
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
     
    #导入数据集
    mnist = input_data.read_data_sets("MNIST_data",one_hot=True)#如果本地没有数据集,此语句会自动下载到对应的文件夹位置,不过网速较慢,不建议
    #每个批次的大小
    batch_size = 100
    #计算一共需要多少个批次
    n_batch = mnist.train.num_examples // batch_size
    #创建两个placeholder
    x = tf.placeholder(tf.float32,[None,784])
    y = tf.placeholder(tf.float32,[None,10])
    #创建一个简单的神经网络
    W = tf.Variable(tf.zeros([784,10]))
    b = tf.Variable(tf.zeros([10]))
    prediction = tf.nn.softmax(tf.matmul(x,W) + b)
    #二次代价函数
    loss = tf.reduce_mean(tf.square(y-prediction))
    #交叉熵损失函数,可以与二次代价函数对比一下那个效果好,运行时只能保留一个
    #loss = -tf.reduce_sum(y_*tf.log(y))
     
    #使用梯度下降法训练
    train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
    #结果存放在一个布尔类型列表中
    correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#对比预测结果的标签是否一致,一致为True,不同为False
    #预测准确率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#将布尔型转化为0.0-1.0之间的数值,True为1.0,False为0.0
    #变量初始化
    init = tf.global_variables_initializer()
     
    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(21):
            for batch in range(n_batch):
                batch_x,batch_y = mnist.train.next_batch(batch_size)#
                sess.run(train_step,feed_dict={x:batch_x,y:batch_y})
            acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
            print('Iter' + str(epoch) + ',Test Accuaracy' + str(acc))
    Iter20,Test Accuaracy0.9139

  • 相关阅读:
    Video Test Pattern Generator(7.0)软件调试记录
    阅读<Video Test Pattern Generator v7.0>笔记
    阅读<Vivado Design Suite Tutorial---Logic Simulation>笔记
    Modelsim使用流程---基于TCL命令的仿真
    BT.656 NTSC制式彩条生成模块(verilog)
    Video to SDI Tx Bridge模块video_data(SD-SDI)处理过程
    时钟分频方法---verilog代码
    手动按键复位程序(包含按键消抖)
    使用Vivado进行行为级仿真
    阅读OReilly.Web.Scraping.with.Python.2015.6笔记---Crawl
  • 原文地址:https://www.cnblogs.com/xiaoshubass/p/13280389.html
Copyright © 2011-2022 走看看