zoukankan      html  css  js  c++  java
  • tensorflow--非线性回归

    算法步骤:

    1. 给定训练样本,x_data和y_data

    2. 定义两个占位符分别接收输入x和输出y

    3. 中间层操作实际为:权值w与输入x矩阵相乘,加上偏差b后,得到中间层输出

    4. 使用tanh函数激活后传给输出层

    5. 输出层操作实际为:权值w与中间层结果矩阵相乘,加上偏差b后,得到输出层输出

    6. 使用tanh函数激活后得到最终结果

    7. 利用y的预测值,与实际的y求出它们间的平均方差,即损失值

    8. 最后使用梯度下降法进行训练,使loss最小化

    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    
    # 生成一组离散点
    x_data = np.linspace(-0.5, 0.5, 200)[:, np.newaxis]
    noise = np.random.normal(0, 0.02, x_data.shape)
    y_data = np.square(x_data) + noise
    
    # 定义两个占位符
    x = tf.placeholder(tf.float32, [None, 1])
    y= tf.placeholder(tf.float32, [None, 1])
    
    # 中间层操作
    Weights_L1 = tf.Variable(tf.random_normal([1,10]))
    biases_L1 = tf.Variable(tf.zeros([1, 10]))
    L1 = tf.nn.tanh(tf.matmul(x, Weights_L1)+biases_L1)
    
    # 输出层操作
    Weights_L2 = tf.Variable(tf.random_normal([10,1]))
    biases_L2 = tf.Variable(tf.zeros([1,1]))
    prediction = tf.nn.tanh(tf.matmul(L1,Weights_L2)+biases_L2)
    
    # 计算损失率
    loss = tf.reduce_mean(tf.square(y-prediction))
    # 使用梯度下降法训练
    train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    
    with tf.Session() as sess:
        # 初始化所有变量
        sess.run(tf.global_variables_initializer())
        for i in range(1000):
            sess.run(train_step, feed_dict={x:x_data, y:y_data})
        prediction_value = sess.run(prediction, feed_dict={x:x_data, y:y_data})
    
        # 画图展示预测结果
        plt.figure()
        plt.scatter(x_data, y_data)
        plt.plot(x_data, prediction_value,'r-*', lw=5)
        plt.show()

    总结:

    1. tensorflow中训练模型前,必须需先初始化变量,否则会报错

    2. 激活函数除了tanh(),还有rule(),Sigmoid(),据说Leaky ReLU 、 PReLU 或者 Maxout效果更佳

    3. 梯度下降法中的学习率需小心设置,避免出现过的死亡神经元

    如tf.train.GradientDescentOptimizer(0.1).minimize(loss)中0.1为学习率
  • 相关阅读:
    1013团队Beta冲刺day3
    1013团队Beta冲刺day2
    1013团队Beta冲刺day1
    beta预备
    团队作业——系统设计
    个人技术博客(α)
    团队作业—预则立&&他山之石
    软工实践- 项目需求规格说明书
    软工第二次作业 团队选题报告
    结队作业-匹配
  • 原文地址:https://www.cnblogs.com/freeyouth/p/11622531.html
Copyright © 2011-2022 走看看