zoukankan      html  css  js  c++  java
  • 1.keras-构建基本简单网络实现线性回归

    构建基本简单网络实现线性回归

    1.创建数据绘制散点图

    import keras
    import numpy as np
    import matplotlib.pyplot as plt
    from keras.models import Sequential
    from keras.layers import Dense
    import tensorflow as tf
    
    # 创建数据绘制散点图
    x_data = np.random.rand(100)
    noise = np.random.normal(0,0.01,x_data.shape)
    y_data = x_data * 0.1 + 0.2 + noise
    
    plt.scatter(x_data,y_data)
    plt.show()

    2.构建模型

    # 构建顺序模型
    model = Sequential()
    # 在模型中添加一个全连接模型
    model.add(Dense(units=1,input_dim=1)) #units=1,input_dim=1输入和输出都是一维的
    model.compile(optimizer='sgd',
                  loss= 'mse')
    for step in range(3000):
        # 每次训练一个batch
        cost = model.train_on_batch(x_data,y_data)
        if step % 500 ==0:
            print('step:',step)
            print('cost',cost)
    # 打印权值和偏移项
    W,b = model.layers[0].get_weights()
    print('W:',W,'b',b)

    out:

    step: 0
    cost 0.026886607
    step: 500
    cost 0.0005393094
    step: 1000
    cost 0.00020998158
    step: 1500
    cost 0.0001274022
    step: 2000
    cost 0.000106695006
    step: 2500
    cost 0.00010150281
    W: [[0.1016009]] b [0.19795756]

    3.预测并绘制预测结果

    # 进行预测值
    y_pred = model.predict(x_data)
    
    # 显示随机点
    plt.scatter(x_data,y_data)
    plt.plot(x_data,y_pred,'r-',lw=3)
    plt.show()

  • 相关阅读:
    盛最多水的容器
    字符串的排序
    整数拆分
    TCP和UDP编程
    旋转图像
    非递减数列
    不同路径2
    不同路径
    压缩拉伸图片
    Java对List分割及使用Spring多线程调用
  • 原文地址:https://www.cnblogs.com/wigginess/p/13062696.html
Copyright © 2011-2022 走看看