zoukankan      html  css  js  c++  java
  • Keras实现LSTM

    LSTM是优秀的循环神经网络(RNN)结构,而LSTM在结构上也比较复杂,对RNN和LSTM还稍有疑问的朋友可以参考:Recurrent Neural Networks vs LSTM

    这里我们将要使用Keras搭建LSTM.Keras封装了一些优秀的深度学习框架的底层实现,使用起来相当简洁,甚至不需要深度学习的理论知识,你都可以轻松快速的搭建你的深度学习网络,强烈推荐给刚入门深度学习的同学使用,当然我也是还没入门的那个。Keras:https://keras.io/,keras的backend有,theano,TensorFlow、CNTk,这里我使用的是TensorFlow。

    下面我们就开始搭建LSTM,实现mnist数据的分类。

    step 0 加载包和定义参数

    mnist的image是28*28的shape,我们定义LSTM的input为(28,),将image一行一行地输入到LSTM的cell中,这样time_step就是28,表示一个image有28行,LSTM的output是30个。
    from keras.datasets import mnist
    from keras.layers import Dense, LSTM
    from keras.utils import to_categorical
    from keras.models import Sequential
    
    #parameters for LSTM
    nb_lstm_outputs = 30  #神经元个数
    nb_time_steps = 28  #时间序列长度
    nb_input_vector = 28 #输入序列

    step 1 数据预处理

    特别注意label要使用one_hot encoding,x_train的shape(60000, 28,28)

    1 #data preprocessing: tofloat32, normalization, one_hot encoding
    2 (x_train, y_train), (x_test, y_test) = mnist.load_data()
    3 x_train = x_train.astype('float32')
    4 x_test = x_test.astype('float32')
    5 x_train /= 255
    6 x_test /= 255
    7 
    8 y_train = to_categorical(y_train, num_classes=10)
    9 y_test = to_categorical(y_test, num_classes=10)

    step 2 搭建模型

    keras搭建模型相当简单,只需要在Sequential容器中不断add新的layer就可以了。

    1 #build model
    2 model = Sequential()
    3 model.add(LSTM(units=nb_lstm_outputs, input_shape=(nb_time_steps, nb_input_vector)))
    4 model.add(Dense(10, activation='softmax'))

    step 3 compile

    模型compile,指定loss function, optimizer, metrics

    1 #compile:loss, optimizer, metrics
    2 model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

     step 4 train

    模型训练,需要指定,epochs训练的轮次数,batch_size。

    1 #train: epcoch, batch_size
    2 model.fit(x_train, y_train, epochs=20, batch_size=128, verbose=1)

     step 5 evaluate

    可以使用model.summary()来查看你的神经网络的架构和参数量等信息。

    1 model.summary()
    2 
    3 score = model.evaluate(x_test, y_test,batch_size=128, verbose=1)
    4 print(score)

     我的最后结果是:

  • 相关阅读:
    如何测试一个网页登陆界面
    吞吐量(TPS)、QPS、并发数、响应时间(RT)概念
    postman接口案例
    接口定义
    socket网络编程
    分页读取文件内容
    hashlib,configparser,logging,模块
    python面向对象编程
    常用模块
    迭代器和生成器
  • 原文地址:https://www.cnblogs.com/yangmang/p/7530416.html
Copyright © 2011-2022 走看看