  • 【509】NLP in Practice Series (6): Classification with LSTM

    Reference: LSTM layer

    1. Syntax

    keras.layers.recurrent.LSTM(units, activation='tanh', recurrent_activation='hard_sigmoid', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', unit_forget_bias=True, kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0.0, recurrent_dropout=0.0)

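    2. Parameters

    • units: dimensionality of the output space (the number of LSTM units)

    • activation: activation function, given as the name of a predefined activation function (see the activation functions reference)

    • recurrent_activation: activation function applied at the recurrent step (see the activation functions reference)

    • use_bias: boolean, whether to use a bias term

    • kernel_initializer: initialization method for the input weights, given as the name of a predefined initializer or as an initializer object. See initializers

    • recurrent_initializer: initialization method for the recurrent kernel, given as the name of a predefined initializer or as an initializer object. See initializers

    • bias_initializer: initialization method for the bias vector, given as the name of a predefined initializer or as an initializer object. See initializers

    • kernel_regularizer: regularizer applied to the input weights, a Regularizer object

    • bias_regularizer: regularizer applied to the bias vector, a Regularizer object

    • recurrent_regularizer: regularizer applied to the recurrent kernel, a Regularizer object

    • activity_regularizer: regularizer applied to the layer output, a Regularizer object

    • kernel_constraint: constraint applied to the input weights, a Constraints object

    • recurrent_constraint: constraint applied to the recurrent kernel, a Constraints object

    • bias_constraint: constraint applied to the bias vector, a Constraints object

    • dropout: float between 0 and 1, the fraction of units to drop for the linear transformation of the inputs

    • recurrent_dropout: float between 0 and 1, the fraction of units to drop for the linear transformation of the recurrent state

    • For the remaining parameters, see the documentation of Recurrent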
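To make the parameters above more concrete, here is a minimal NumPy sketch of a single LSTM time step. It is not the Keras implementation: the helper names (`hard_sigmoid`, `init_lstm_weights`, `lstm_step`) are hypothetical, the weights are plain small random values rather than the `glorot_uniform` and `orthogonal` initializers, and the gate order (input, forget, cell, output) and the hard-sigmoid formula are assumed to follow Keras 2.x.

```python
import numpy as np

def hard_sigmoid(x):
    # Piecewise-linear stand-in for the sigmoid (assumed Keras 2.x form):
    # 0 below -2.5, 1 above 2.5, linear in between.
    return np.clip(0.2 * x + 0.5, 0.0, 1.0)

def init_lstm_weights(input_dim, units, seed=0):
    # Simplified initialization (assumption): small Gaussian values instead of
    # the glorot_uniform / orthogonal initializers named in the signature.
    rng = np.random.default_rng(seed)
    kernel = rng.normal(scale=0.1, size=(input_dim, 4 * units))        # kernel_*
    recurrent_kernel = rng.normal(scale=0.1, size=(units, 4 * units))  # recurrent_*
    bias = np.zeros(4 * units)                                         # bias_initializer='zeros'
    bias[units:2 * units] = 1.0                                        # unit_forget_bias=True
    return kernel, recurrent_kernel, bias

def lstm_step(x_t, h_prev, c_prev, kernel, recurrent_kernel, bias):
    # One pre-activation vector holds all four gates side by side.
    z = x_t @ kernel + h_prev @ recurrent_kernel + bias
    z_i, z_f, z_c, z_o = np.split(z, 4, axis=-1)
    i = hard_sigmoid(z_i)                # input gate  (recurrent_activation)
    f = hard_sigmoid(z_f)                # forget gate (recurrent_activation)
    o = hard_sigmoid(z_o)                # output gate (recurrent_activation)
    c = f * c_prev + i * np.tanh(z_c)    # new cell state (activation='tanh')
    h = o * np.tanh(c)                   # new hidden state, shape (batch, units)
    return h, c

input_dim, units = 32, 32
kernel, recurrent_kernel, bias = init_lstm_weights(input_dim, units)

rng = np.random.default_rng(1)
x = rng.normal(size=(2, 5, input_dim))   # (batch, timesteps, features)
h = np.zeros((2, units))
c = np.zeros((2, units))
for t in range(x.shape[1]):
    h, c = lstm_step(x[:, t, :], h, c, kernel, recurrent_kernel, bias)

n_params = kernel.size + recurrent_kernel.size + bias.size
print(h.shape)   # (2, 32)
print(n_params)  # 8320
```

With `units=32` and 32-dimensional inputs, the layer holds 4 × (32×32 + 32×32 + 32) = 8320 weights, which is the size of the `LSTM(32)` layer stacked on `Embedding(max_features, 32)` in section 3.2.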

    3. Implementation

    3.1 Loading the data
    from keras.datasets import imdb
    from keras.preprocessing import sequence
    
    max_features = 10000  # number of words to consider as features
    maxlen = 500  # cut texts after this number of words (among top max_features most common words)
    batch_size = 32
    
    print('Loading data...')
    (input_train, y_train), (input_test, y_test) = imdb.load_data(num_words=max_features)
    print(len(input_train), 'train sequences')
    print(len(input_test), 'test sequences')
    
    print('Pad sequences (samples x time)')
    input_train = sequence.pad_sequences(input_train, maxlen=maxlen)
    input_test = sequence.pad_sequences(input_test, maxlen=maxlen)
    print('input_train shape:', input_train.shape)
    print('input_test shape:', input_test.shape)
    

      output:

    Loading data...
    25000 train sequences
    25000 test sequences
    Pad sequences (samples x time)
    input_train shape: (25000, 500)
    input_test shape: (25000, 500)
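The padding step above can be sketched in plain Python. This is only an illustration of what `pad_sequences` does with its default settings as I understand them (zeros added at the front, over-long sequences cut at the front); `pad_pre` and the toy `reviews` list are hypothetical names and data, not part of Keras or the IMDB dataset.

```python
def pad_pre(sequences, maxlen, value=0):
    """Left-pad (or left-truncate) every sequence to exactly maxlen items."""
    padded = []
    for seq in sequences:
        seq = list(seq)[-maxlen:]                       # keep only the last maxlen items
        padded.append([value] * (maxlen - len(seq)) + seq)
    return padded

# Toy word-index sequences of different lengths (made-up values).
reviews = [[11, 12], [21, 22, 23, 24, 25, 26], [31, 32, 33, 34]]
print(pad_pre(reviews, maxlen=4))
# [[0, 0, 11, 12], [23, 24, 25, 26], [31, 32, 33, 34]]
```

Every row ends up with the same length, which is why the 25000 reviews become a single array of shape (25000, 500) when `maxlen=500`.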

    3.2 Training the model

    from keras.models import Sequential
    from keras.layers import Embedding, Dense, LSTM
    
    model = Sequential()
    model.add(Embedding(max_features, 32))
    model.add(LSTM(32))
    model.add(Dense(1, activation='sigmoid'))
    
    model.compile(optimizer='rmsprop',
                  loss='binary_crossentropy',
                  metrics=['acc'])
    history = model.fit(input_train, y_train,
                        epochs=10,
                        batch_size=128,
                        validation_split=0.2)
    

      outputs:

    Train on 20000 samples, validate on 5000 samples
    Epoch 1/10
    20000/20000 [==============================] - 108s - loss: 0.5038 - acc: 0.7574 - val_loss: 0.3853 - val_acc: 0.8346
    Epoch 2/10
    20000/20000 [==============================] - 108s - loss: 0.2917 - acc: 0.8866 - val_loss: 0.3020 - val_acc: 0.8794
    Epoch 3/10
    20000/20000 [==============================] - 107s - loss: 0.2305 - acc: 0.9105 - val_loss: 0.3125 - val_acc: 0.8688
    Epoch 4/10
    20000/20000 [==============================] - 107s - loss: 0.2033 - acc: 0.9261 - val_loss: 0.4013 - val_acc: 0.8574
    Epoch 5/10
    20000/20000 [==============================] - 107s - loss: 0.1749 - acc: 0.9385 - val_loss: 0.3273 - val_acc: 0.8912
    Epoch 6/10
    20000/20000 [==============================] - 107s - loss: 0.1543 - acc: 0.9457 - val_loss: 0.3505 - val_acc: 0.8774
    Epoch 7/10
    20000/20000 [==============================] - 107s - loss: 0.1417 - acc: 0.9493 - val_loss: 0.4485 - val_acc: 0.8396
    Epoch 8/10
    20000/20000 [==============================] - 106s - loss: 0.1331 - acc: 0.9522 - val_loss: 0.3242 - val_acc: 0.8928
    Epoch 9/10
    20000/20000 [==============================] - 106s - loss: 0.1147 - acc: 0.9618 - val_loss: 0.4216 - val_acc: 0.8746
    Epoch 10/10
    20000/20000 [==============================] - 106s - loss: 0.1092 - acc: 0.9628 - val_loss: 0.3972 - val_acc: 0.8758
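The log above is easier to read once the validation numbers are summarized. The following sketch copies the per-epoch values from the log by hand (in a real run they would presumably be available as `history.history`) and picks out the best epochs; the variable names are illustrative only.

```python
# Per-epoch values copied from the training log above.
train_acc = [0.7574, 0.8866, 0.9105, 0.9261, 0.9385,
             0.9457, 0.9493, 0.9522, 0.9618, 0.9628]
val_loss = [0.3853, 0.3020, 0.3125, 0.4013, 0.3273,
            0.3505, 0.4485, 0.3242, 0.4216, 0.3972]
val_acc = [0.8346, 0.8794, 0.8688, 0.8574, 0.8912,
           0.8774, 0.8396, 0.8928, 0.8746, 0.8758]

# Epochs are 1-based in the log, list indices are 0-based.
best_loss_epoch = min(range(len(val_loss)), key=val_loss.__getitem__) + 1
best_acc_epoch = max(range(len(val_acc)), key=val_acc.__getitem__) + 1

# Gap between training and validation accuracy after the last epoch.
final_gap = train_acc[-1] - val_acc[-1]

print(best_loss_epoch, val_loss[best_loss_epoch - 1])  # 2 0.302
print(best_acc_epoch, val_acc[best_acc_epoch - 1])     # 8 0.8928
print(f"{final_gap:.3f}")                              # 0.087
```

Validation loss is lowest at epoch 2 and validation accuracy peaks at epoch 8 (0.8928), while training accuracy keeps climbing to 0.9628. That pattern suggests the model starts to overfit after the first few epochs, so training for fewer epochs or adding `dropout` / `recurrent_dropout` may be worth trying.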
  • Original article: https://www.cnblogs.com/alex-bn-lee/p/14197777.html