zoukankan      html  css  js  c++  java
  • TensorFlow_曲线拟合

    # coding:utf-8
    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    import os
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    from Sigmoid import sigmoid
    
    x_data = np.arange(-2*np.pi,2*np.pi,0.1).reshape(-1,1)
    y_data = np.sin(x_data).reshape(-1,1)
    # x_data = sigmoid(x_data)
    # y_data = sigmoid(y_data)
    print(x_data.shape,y_data.shape)
    
    # 建立tensorflow模型
    x = tf.placeholder(tf.float32,[None,1])
    y = tf.placeholder(tf.float32,[None,1])
    # 首层
    w = tf.Variable(tf.random_normal([1,10]))
    b = tf.Variable(tf.zeros([1,10]))
    # 中间层
    w1 = tf.Variable(tf.random_normal([10,20]))
    b1 = tf.Variable(tf.zeros([1,1]))
    # 输出层
    w2 = tf.Variable(tf.random_normal([20,1]))
    b2 = tf.Variable(tf.zeros([1,1]))
    
    y_pred = tf.matmul(x,w)+b
    # 激活函数
    y_pred_1 = tf.nn.tanh(y_pred)
    yy  = tf.matmul(y_pred_1,w1)+b1
    y_pred_ = tf.nn.tanh(yy)
    y1 = tf.matmul(y_pred_,w2)+b2
    y2 = tf.nn.tanh(y1)
    #二次代价函数
    loss = tf.reduce_mean(tf.square(y-y2))
    # 训练方法:梯度下降法
    train_model = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
    #结果存放在一个布尔型列表中
    correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y2,1))#argmax返回一维张量中最大的值所在的位置
    #求准确率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    # 初始化变量
    inint = tf.global_variables_initializer()
    # 开始训练
    with tf.Session() as sess:
        sess.run(inint)
        for i in range(10000):
            sess.run(train_model,feed_dict={x:x_data,y:y_data})
            if i%1000==0:
                auc = sess.run(accuracy,feed_dict={x:x_data,y:y_data})
                print('迭代次数:%d'%i,'auc:%d'%auc,' 损失函数(loss):',sess.run(loss,feed_dict={x:x_data,y:y_data}))
        y_ = sess.run(y2,feed_dict={x:x_data})
        sess.close()
    
        plt.figure('tensorflow',figsize=(12,6))
        plt.scatter(x_data, y_data,label='sin(x)的值')
        plt.plot(x_data,y_,'r',linewidth=1,label='tensorflow拟合值')
        plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置字体为SimHei显示中文
        plt.rcParams['axes.unicode_minus'] = False  # 设置正常显示符号
        plt.title('tensorflow实现y=sin(x)拟合')
        plt.xlabel('x-values',{'size':15})
        plt.ylabel('y-values-sin(x)',{'size':15})
        plt.legend(loc='upper right')
        plt.show()
  • 相关阅读:
    JAVA 基本数据类型长度
    字符编码详解
    几种编码方式
    Java1.5泛型指南中文版(Java1.5 Generic Tutorial)
    java泛型小问题
    java中的equals()方法
    Java泛型中E、T、K、V等的含义
    数据库的基本操作
    Mysql数据类型简介(大概了解)
    [BZOJ 2007] 海拔
  • 原文地址:https://www.cnblogs.com/wuzaipei/p/9471467.html
Copyright © 2011-2022 走看看