zoukankan      html  css  js  c++  java
  • keras 入门模型训练

    # -*- coding: utf-8 -*-
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.models import load_model
    
    import matplotlib.pyplot as plt
    import numpy as np
    
    np.random.seed(1)  # for reproducibility
    
    X = np.random.rand(200)
    np.random.shuffle(X)  # randomize the data
    Y = X + np.random.normal(0, 0.05, (200,))
    
    X_train, Y_train = X[:160], Y[:160]  # first 160 data points
    X_test, Y_test = X[160:], Y[160:]  # last 40 data points
    model = Sequential()
    
    model.add(Dense(output_dim=1, input_dim=1))
    
    model.compile(loss='mse', optimizer='sgd')
    print('test before save: ', model.predict(X_test[0:1]))
    for step in range(10000):
        # cost = model.train_on_batch(X_train, Y_train)
        cost = model.fit(X_train, Y_train, nb_epoch=1, batch_size=160)
    
    # save model
    model.save('my_model.h5')  # HDF5 file, you have to pip3 install h5py if don't have it
    del model  # deletes the existing model
    
    # load model
    model = load_model('my_model.h5')
    print('test after load: ', model.predict(X_test[0:1]))
    
    # 模型预测值
    predictY = model.predict(X[:])
    predictY= np.asarray(predictY)
    predictY = np.reshape(predictY,(200))
    
    # 绘图
    plt.figure('Accuracy')
    plt.plot(X,Y,'ro')  # plot绘制折线图
    plt.plot(X,predictY,'b^')
    plt.draw()  # 显示绘图
    plt.pause(20)  #显示20秒
    plt.savefig("Accuracy.jpg")  #保存图象
    plt.close()   #关闭图表

    红色的点是真实的数据分布,绿色的点是模型预测出来的数据,迭代300轮效果:



    800轮:


    1500轮:


    3000轮:


  • 相关阅读:
    JMS学习四(ActiveMQ消息过滤)
    JMS学习三(ActiveMQ消息的可靠性)
    JMS学习二(简单的ActiveMQ实例)
    JMS学习一(JMS介绍)
    Linux iostat监测IO状态
    git删除所有提交历史记录
    MySQL查看数据库相关信息
    Java面试通关要点汇总集
    java开发需掌握技能2
    java开发需掌握技能1
  • 原文地址:https://www.cnblogs.com/mtcnn/p/9411756.html
Copyright © 2011-2022 走看看