zoukankan      html  css  js  c++  java
  • Keras 训练一个单层全连接网络的线性回归模型

    1、准备环境,探索数据

    import numpy as np
    from keras.models import Sequential 
    from keras.layers import Dense
    import matplotlib.pyplot as plt
    
    # 创建数据集
    rng = np.random.RandomState(27) 
    X = np.linspace(-3, 5, 300)
    rng.shuffle(X)    # 将数据集随机化
    y = 0.5 * X + 1 + np.random.normal(0, 0.05, 300) # 假设真实模型为:y = 0.5X + 1
    
    # 绘制数据集
    plt.scatter(X, y, s=0.5)
    plt.show()

    2、准备数据训练模型

    # 划分训练集和测试集
    X_train, y_train = X[:400], y[:400]     
    X_test, y_test = X[-100:], y[-100:]       
    
    # 定义模型
    model = Sequential () # 用 Keras 序贯模型(Sequential)定义一个单输入单输出的模型 model
    model.add(Dense(output_dim=1, input_dim=1)) # 通过 add()方法一层, Dense 是全连接层,第一层需要定义输入
                                                
    # 设置模型参数
    model.compile(loss='mse', optimizer='sgd')  # 通过compile()方法选择损失函数(均方误差)和 优化器(随机梯度下降
    
    # 开始训练
    print('Training ==========')
    for step in range(301):
        cost = model.train_on_batch(X_train, y_train) # Keras 的 train_on_batch() 函数训练模型
        if step % 100 == 0:
            print('train cost: ', cost)
    

    3、测试训练好的模型

    print('
    Testing ==========')
    cost = model.evaluate(X_test, y_test, batch_size=40)
    print('test cost:', cost)
    W, b = model.layers[0].get_weights()    # 查看训练出的网络参数

    print('Weights=', W, ' biases=', b) # 由于网络只有一层,且每次训练的输入和输出只有一个节点,因此第一层训练出 y=WX+b 的模型,其中 W,b 为训练出的参数

     最终的测试 cost 为: 0.0026768923737108706

    4、可视化测试结果

    y_pred = model.predict(X_test)  # 用测试集进行预测
    plt.scatter(X_test, y_test, s=4)    # 绘制测试点图
    plt.plot(X_test, y_pred, lw=0.7)    # 绘制回归直线
    plt.show()

    。。。

  • 相关阅读:
    【题解】Red-Blue Graph Codeforces 1288F 上下界费用流
    【题解】The Magician HDU 6565 大模拟
    HAOI2018游记
    【题解】【THUSC 2016】成绩单 LOJ 2292 区间dp
    【题解】【雅礼集训 2017 Day5】远行 LOJ 6038 LCT
    【题解】Catering World Finals 2015 上下界费用流
    《无问西东...》
    为了世界的和平~一起上caioj~~~!
    新征程~起航!
    bzoj4240: 有趣的家庭菜园(树状数组+贪心思想)
  • 原文地址:https://www.cnblogs.com/shanger/p/12005542.html
Copyright © 2011-2022 走看看