zoukankan      html  css  js  c++  java
  • 用Keras搭建神经网络 简单模版(五)——RNN LSTM Regressor 循环神经网络

    # -*- coding: utf-8 -*-
    import numpy as np
    np.random.seed(1337)
    import matplotlib.pyplot as plt
    from keras.models import Sequential
    from keras.layers import LSTM,TimeDistributed,Dense
    from keras.optimizers import Adam
    
    BATCH_START = 0 
    TIME_STEPS = 20 
    BATCH_SIZE = 50
    INPUT_SIZE = 1 
    OUTPUT_SIZE = 1
    CELL_SIZE = 20
    LR = 0.006
    
    def get_batch():
        global BATCH_START,TIME_STEPS
        # xs shape(50,20,)
        #xs=np.arange(0,0+20*50).reshape(50,20)
        xs = np.arange(BATCH_START,BATCH_START+TIME_STEPS*BATCH_SIZE).reshape((BATCH_SIZE,TIME_STEPS)) / (10*np.pi)
        seq = np.sin(xs)
        res = np.cos(xs)
        BATCH_START += TIME_STEPS
        #plt.plot(xs[0,:],res[0,:],'r',xs[0,:],seq[0,:],'b--')
        #plt.show()
        return [seq[:,:,np.newaxis], res[:,:,np.newaxis],xs]
    
    #get_batch()
    #exit()
        
        
    model = Sequential()
    
    model.add(LSTM(output_dim=CELL_SIZE, 
                   return_sequences=True, # 每一个时间点都输出一个output
                   batch_input_shape=(BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
                   stateful = True,# batch和batch之间是否有联系
                   # 前一个batch的最后一步和后一个batch的第一步是有联系的
            )) 
    
    model.add(TimeDistributed(Dense(OUTPUT_SIZE))) # dense对每一个output连接,对每一个时间点都要计算
    
    adam = Adam(LR)
    model.compile(optimizer = adam,
                  loss = 'mse',)
    
    print('Training ------------')
    for step in range(501):
        # data shape = (batch_num,steps,inputs/output)
        X_batch, Y_batch, xs = get_batch()
        cost = model.train_on_batch(X_batch, Y_batch)
        pred = model.predict(X_batch,BATCH_SIZE)
        plt.plot(xs[0,:], Y_batch[0].flatten(),'r',xs[0,:],pred.flatten()[:TIME_STEPS],'b--')
        plt.ylim((-1.2,1.2))
        plt.draw()
        plt.pause(0.5)
        if step % 10 == 0:
            print('train cost',cost)
    
    
     

  • 相关阅读:
    鸽巢原理 学习笔记
    POJ 1811 Prime Test
    Ubuntu下pdf乱码问题解决方法
    POJ 基础数学
    SRM遇到的一个数论技巧——最大公约数和最小公倍数的关系
    计算几何初步模板
    矩阵快速幂 学习笔记
    ZOJ 2849 Attack of Panda Virus (优先队列 priority_queue)
    欧几里德算法和扩展欧几里德算法
    记部分HASH函数
  • 原文地址:https://www.cnblogs.com/caiyishuai/p/13270687.html
Copyright © 2011-2022 走看看