zoukankan      html  css  js  c++  java
  • 假期第15天

    #创建一个神经网络层
    def add_layer(input,in_size,out_size,activation_function=None):
    """
    :param input: 数据输入
    :param in_size: 输入大小
    :param out_size: 输出大小
    :param activation_function: 激活函数(默认没有)
    :return:output:数据输出
    """
    Weight=tf.Variable(tf.random_normal([in_size,out_size]) )
    biases=tf.Variable(tf.zeros([1,out_size]) +0.1 )
    W_mul_x_plus_b=tf.matmul(input,Weight) + biases
    #根据是否有激活函数
    if activation_function == None:
    output=W_mul_x_plus_b
    else:
    output=activation_function(W_mul_x_plus_b)
    return output

    #创建一个具有输入层,隐藏层,输出层的三层神经网络,神经元个数分别为1,10,1
    x_data=np.linspace(-1,1,300)[:,np.newaxis] # 创建输入数据 np.newaxis分别是在列(第二维)上增加维度,原先是(300,)变为(300,1)
    noise=np.random.normal(0,0.05,x_data.shape)
    y_data=np.square(x_data)+1+noise # 创建输入数据对应的输出

    # print(x_data.shape)
    # x=np.linspace(-1,1,300)
    # print(x.shape)

    #定义输入数据
    xs=tf.placeholder(tf.float32,[None,1])
    ys=tf.placeholder(tf.float32,[None,1])

    #定义一个隐藏层
    hidden_layer1=add_layer(xs,1,10,activation_function=tf.nn.relu)
    #定义一个输出层
    prediction=add_layer(hidden_layer1,10,1,activation_function=None)

    #求解神经网络参数
    #1.定义损失函数
    loss=tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction) ,reduction_indices=[1] ))
    #2.定义训练过程
    train_step=tf.train.GradientDescentOptimizer(0.1).minimize(loss)

    init=tf.global_variables_initializer()
    sess=tf.Session()
    sess.run(init)


    # 绘制求解的曲线
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.scatter(x_data, y_data)
    plt.ion() #Turn interactive mode on 开启互动模式
    plt.show() #Display a figure


    #3.进行训练
    for i in range(1000):
    sess.run(train_step,feed_dict={xs:x_data,ys:y_data})
    if i%100==0:
    try:
    ax.lines.remove(lines[0]) #擦除之前的轨迹
    except Exception:
    pass

    print(sess.run(loss,feed_dict={xs:x_data,ys:y_data} ) )
    # 计算预测值
    prediction_value = sess.run(prediction, feed_dict={xs: x_data})
    #print(prediction.shape)
    # 绘制预测值
    lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
    plt.pause(0.1)
    # 关闭Session
    plt.pause(1000)

    #关闭sess
    sess.close()

  • 相关阅读:
    城市生态规划关键技术方法之三:城市生态系统服务功能理论与方法
    AE中用矢量数据剪裁栅格
    城市生态规划的基本原理与程序
    城市生态规划关键技术方法之一:生态系统健康理论与方法
    城市生态规划的基本概念、内容与方法
    终于找到使用Sql Server Management Studio导致蓝屏的罪魁祸首了
    保证相同类型的MDI子窗体只会被打开一次的方法
    甩掉数据字典,让Sql Server数据库也来一个自描述
    新鲜试用IE8 beta2
    让我目瞪口呆的program.exe
  • 原文地址:https://www.cnblogs.com/yyl141/p/14300593.html
Copyright © 2011-2022 走看看