  • Keras LSTM text classification example

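    # A simple text classification task on the IMDB dataset

    # Architecture: one Embedding layer + one LSTM layer + one fully connected (Dense) layer

    # Built with Keras 2.1.1 and TensorFlow 1.4.0

    Code: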
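    The three-layer architecture can be sanity-checked by counting trainable parameters by hand, which should line up with the totals `model.summary()` prints. This is a minimal sketch assuming the sizes used in the script (a 20000-word vocabulary, 128-dimensional embeddings, 128 LSTM units, one sigmoid output) and the Keras default of one bias vector per LSTM gate; the helper functions are hypothetical names written only for this illustration.

```python
# Hand-computed parameter counts for the Embedding -> LSTM -> Dense model.
# Sizes mirror the script: max_features=20000, 128-dim embeddings, 128 LSTM units.

def embedding_params(vocab_size, embed_dim):
    # One trainable vector of length embed_dim per vocabulary index.
    return vocab_size * embed_dim

def lstm_params(input_dim, units):
    # Four gates (input, forget, cell candidate, output), each with an input
    # kernel (input_dim x units), a recurrent kernel (units x units) and a bias.
    return 4 * (input_dim * units + units * units + units)

def dense_params(input_dim, units):
    # Weight matrix plus one bias per output unit.
    return input_dim * units + units

layers = {
    "embedding": embedding_params(20000, 128),
    "lstm": lstm_params(128, 128),
    "dense": dense_params(128, 1),
}
total = sum(layers.values())

for name, count in layers.items():
    print("{:>9}: {:,}".format(name, count))
print("{:>9}: {:,}".format("total", total))  # total: 2,691,713
```

    Almost all of the parameters sit in the embedding table, which is one reason the docstring's remark about simpler models (TF-IDF + logistic regression) being competitive on a dataset this size is plausible.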

    '''Trains an LSTM model on the IMDB sentiment classification task.
    The dataset is actually too small for LSTM to be of any advantage
    compared to simpler, much faster methods such as TF-IDF + LogReg.
    # Notes
    - RNNs are tricky. Choice of batch size is important,
    choice of loss and optimizer is critical, etc.
    Some configurations won't converge.
    - LSTM loss decrease patterns during training can be quite different
    from what you see with CNNs/MLPs/etc.
    '''
    from __future__ import print_function

    from keras.preprocessing import sequence
    from keras.models import Sequential
    from keras.layers import Dense, Embedding
    from keras.layers import LSTM
    from keras.datasets import imdb

    max_features = 20000  # vocabulary size: keep only the 20,000 most frequent words
    maxlen = 80  # pad/truncate every review to this many words (by default truncation keeps the last maxlen words)
    batch_size = 32

    print('Loading data...')
    # Each review is a list of word indices; labels are 0 (negative) or 1 (positive).
    (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
    print(len(x_train), 'train sequences')
    print(len(x_test), 'test sequences')

    print('Pad sequences (samples x time)')
    # Turn the variable-length lists into an int array of shape (num_samples, maxlen).
    x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
    x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
    print('x_train shape:', x_train.shape)
    print('x_test shape:', x_test.shape)

    print('Build model...')
    model = Sequential()
    # Map each word index to a trainable 128-dimensional vector.
    model.add(Embedding(max_features, 128))
    # 128 LSTM units; dropout is applied to the inputs and to the recurrent state.
    model.add(LSTM(128, dropout=0.2, recurrent_dropout=0.2))
    # A single sigmoid unit outputs the probability that a review is positive.
    model.add(Dense(1, activation='sigmoid'))
    model.summary()

    # try using different optimizers and different optimizer configs
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

    print('Train...')
    # Note: the test set doubles as validation data here, so the validation metrics are test-set metrics.
    model.fit(x_train, y_train, batch_size=batch_size, epochs=15, validation_data=(x_test, y_test))
    # score is the test loss (binary cross-entropy); acc is the test accuracy.
    score, acc = model.evaluate(x_test, y_test, batch_size=batch_size)
    print('Test score:', score)
    print('Test accuracy:', acc)
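    With `maxlen = 80`, `pad_sequences` forces every review to exactly 80 tokens. Under its defaults, which the script relies on (`padding='pre'`, `truncating='pre'`, fill value 0), long reviews keep only their last 80 word indices and short ones are left-padded with 0; index 0 is free for padding because `imdb.load_data` reserves the low indices (start marker 1, out-of-vocabulary marker 2, real words offset by 3). The sketch below is a hypothetical pure-NumPy re-implementation of that default behaviour for illustration only, not Keras's own code; `pad_pre` is a made-up name.

```python
import numpy as np

def pad_pre(sequences, maxlen, value=0):
    """Hypothetical re-implementation of pad_sequences' defaults:
    truncating='pre' (drop tokens from the front) and padding='pre'
    (put the fill value in front)."""
    out = np.full((len(sequences), maxlen), value, dtype="int32")
    for i, seq in enumerate(sequences):
        if len(seq) == 0:
            continue  # an empty review stays all padding
        kept = seq[-maxlen:]          # truncating='pre': keep the last maxlen tokens
        out[i, -len(kept):] = kept    # padding='pre': fill value goes in front
    return out

# Two toy "reviews" of word indices (1 plays the role of the start marker).
reviews = [[1, 14, 22], [1, 5, 7, 9, 11, 13]]
print(pad_pre(reviews, maxlen=4).tolist())  # [[0, 1, 14, 22], [7, 9, 11, 13]]
```

    Keeping the end of each review means the LSTM's final hidden state, which is what the Dense layer sees, is computed right after the last real words rather than after a run of padding.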

    Result:

    Test accuracy: 0.81248
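    Both printed numbers come from `model.evaluate`: `score` is the mean binary cross-entropy on the test set, and because the loss is `binary_crossentropy`, Keras resolves `metrics=['accuracy']` to binary accuracy, the fraction of reviews whose sigmoid output rounded to 0 or 1 matches the label (roughly 81% here). A minimal NumPy sketch of both quantities, using made-up labels and predictions rather than output from the model above; the clipping constant 1e-7 is assumed to match Keras's default epsilon.

```python
import numpy as np

EPS = 1e-7  # assumed clipping constant (Keras's default epsilon)

def binary_crossentropy(y_true, y_pred):
    # Clip to avoid log(0), then average the per-sample negative log-likelihood.
    p = np.clip(y_pred, EPS, 1 - EPS)
    return float(np.mean(-(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))))

def binary_accuracy(y_true, y_pred):
    # Round each sigmoid output to 0 or 1 and count matches with the labels.
    return float(np.mean(np.equal(y_true, np.round(y_pred))))

# Hypothetical labels and sigmoid outputs for four reviews (1 = positive).
y_true = np.array([1.0, 0.0, 1.0, 0.0])
y_pred = np.array([0.9, 0.2, 0.4, 0.6])

print("loss:", round(binary_crossentropy(y_true, y_pred), 4))
print("accuracy:", binary_accuracy(y_true, y_pred))  # 2 of 4 correct -> 0.5
```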
  • Original post: https://www.cnblogs.com/cnXuYang/p/8992865.html
Copyright © 2011-2022 走看看