zoukankan      html  css  js  c++  java
  • 2.keras-构建基本网络实现非线性回归

    构建基本网络实现非线性回归

    1.加载显示数据集

    import tensorflow as tf
    import numpy as np
    import keras
    from keras.layers import *
    from keras.models import Sequential
    import matplotlib.pyplot as plt
    from keras.optimizers import SGD
    
    x_data = np.linspace(-0.5,0.5,200)
    noise = np.random.normal(0,0.02,x_data.shape)
    y_data = np.square(x_data) + noise
    
    # 显示
    plt.scatter(x_data,y_data)
    plt.show()

    2.构建网络输出结果

    # 构建顺序模型
    model = Sequential()
    # 在模型中添加一个全连接模型
    # 机构为1-10-1
    model.add(Dense(units=10,input_dim=1,activation='tanh'))
    model.add(Dense(units=1,activation='tanh')) #units=1,input_dim=1输入和输出都是一维的
    # 自定义SGD
    sgd = SGD(lr=0.3)
    model.compile(optimizer=sgd,
                  loss= 'mse')
    for step in range(3000):
        # 每次训练一个batch
        cost = model.train_on_batch(x_data,y_data)
        if step % 500 ==0:
            print('step:',step)
            print('cost',cost)
    # 打印权值和偏移项
    W,b = model.layers[0].get_weights()
    print('W:',W,'b',b)

    out:

    step: 0
    cost 0.066955164
    step: 500
    cost 0.0051592756
    step: 1000
    cost 0.019756123
    step: 1500
    cost 0.0018320761
    step: 2000
    cost 0.0007798174
    step: 2500
    cost 0.0005237385
    W: [[-0.06731744 0.8597639 0.4614085 0.02440587 -0.04702926 -0.03291976
    0.78343517 -0.0447227 1.1036808 1.4795449 ]] b [-0.04047519 0.27002558 -0.06009897 -0.20481145 -0.13842463 -0.27928182
    0.21476284 0.28802755 0.44497478 -0.59868914]

    3.预测并绘制预测结果

    # 进行预测值
    y_pred = model.predict(x_data)
    
    # 显示随机点
    plt.scatter(x_data,y_data)
    plt.plot(x_data,y_pred,'r-',lw=3)
    plt.show()

  • 相关阅读:
    国外程序猿整理的机器学习资源大全
    一个改动配置文件的linux shell script
    python高精度浮点型计算的诡异错误
    错误:'dict' object is not callable
    AssertionError while merging cells with xlwt (Python)
    Python => ValueError: unsupported format character 'Y' (0x59)
    [转]Python的3种格式化字符串方法
    python requirements使用方法
    conda虚拟环境实践
    迭代器中next()的用法
  • 原文地址:https://www.cnblogs.com/wigginess/p/13062719.html
Copyright © 2011-2022 走看看