zoukankan      html  css  js  c++  java
  • keras_非线性回归

    keras是tensorflow里的高阶API库,几行代码就可以构建一个神经网络。

    可以利用keras做非线性回归。

    生成一组数据

    1 import numpy as np
    2 import matplotlib.pyplot as plt
    3 x=np.linspace(-10,10,100)
    4 y=x**2+np.random.rand(100)*10
    5 fig1,ax1=plt.subplots()
    6 ax1.scatter(x,y)
    7 plt.show()

    完整程序:

     1 import tensorflow as tf
     2 import numpy as np
     3 import matplotlib.pyplot as plt
     4 
     5 #生成试验数据
     6 x=np.linspace(-10,10,100)  
     7 x=tf.convert_to_tensor(x)
     8 y=x**2+np.random.rand(100)*10
     9 
    10 #建立模型
    11 #设置层
    12 model=tf.keras.Sequential([
    13     #1-2层,输入层一个神经元,隐层100个神经元,激活函数sigmoid
    14     tf.keras.layers.Dense(100,activation='sigmoid',input_dim=1),
    15     tf.keras.layers.Dense(1)#3层,输出层一个神经元,输出一个数
    16 ])
    17 #编译模型
    18 ssgd=tf.keras.optimizers.SGD(lr=0.005)#设置学习率0.005
    19 #损失函数设为:mse(均方误差),优化器设为:SGD(梯度下降)
    20 model.compile(loss='mse',optimizer=ssgd)
    21 #训练模型
    22 model.fit(x,y,epochs=2000)#迭代2000次
    23 print('训练结束')
    24 #绘图
    25 a=model.predict(x)
    26 fig1,ax1=plt.subplots()
    27 ax1.scatter(x,y)
    28 fig,ax=plt.subplots()
    29 ax.scatter(x,y)
    30 ax.plot(x,a,'r')
    31 plt.show()

  • 相关阅读:
    npm改为淘宝镜像
    html中table中td内容换行
    git 切换文件夹路径
    git经常使用的命令
    day16
    day15
    day13
    day14
    day12
    day11
  • 原文地址:https://www.cnblogs.com/winterbear/p/12650539.html
Copyright © 2011-2022 走看看