#!/usr/bin/env python
# coding=utf-8
"""Train a small LSTM encoder/decoder to add two non-negative integers.

Each question such as ``'12+345'`` is right-padded with spaces to
``MAXLEN`` characters and one-hot encoded; the target is the sum,
right-padded to ``ANSWER_LEN`` characters.  Running the file as a script
generates the data, trains the network, prints a few sample predictions
and writes the architecture and weights to disk.
"""
import random
import string

import numpy as np

MAXLEN = 7  # 3+3+1: two operands of up to three digits plus the '+' sign
ANSWER_LEN = 4  # the largest possible sum (987 + 987) has four digits

# gen_num() can produce 10 + 81 + 648 = 739 distinct integers (one-, two- and
# three-digit values without a repeated digit), hence this many unordered
# operand pairs.  Asking generate_dataset() for more would never terminate.
MAX_UNIQUE_PAIRS = 739 * 740 // 2


class CharacterTable(object):
    """One-hot codec for the 12-symbol alphabet ``'0123456789+ '``.

    The trailing space is the padding symbol.
    """

    def __init__(self, maxlen):
        """Build the char<->index maps; ``maxlen`` is the default row count."""
        self.chars = string.digits + '+ '
        self.char_indices = {c: i for i, c in enumerate(self.chars)}
        self.indice_chars = {i: c for i, c in enumerate(self.chars)}
        self.maxlen = maxlen

    def encode(self, strs, maxlen=None):
        """Return a ``(maxlen, len(self.chars))`` one-hot array for ``strs``.

        Rows beyond ``len(strs)`` stay all-zero.  A falsy ``maxlen`` falls
        back to ``self.maxlen``.

        Raises:
            KeyError: ``strs`` contains a character outside ``self.chars``.
            IndexError: ``strs`` is longer than ``maxlen``.
        """
        maxlen = maxlen if maxlen else self.maxlen
        vec = np.zeros((maxlen, len(self.chars)))
        for row, char in enumerate(strs):
            vec[row, self.char_indices[char]] = 1
        return vec

    def decode(self, vec, calc_argmax=True):
        """Turn ``vec`` back into a string.

        With ``calc_argmax`` true, ``vec`` is reduced with ``argmax`` over
        its last axis first; otherwise it must already be a sequence of
        character indices.
        """
        if calc_argmax:
            vec = vec.argmax(axis=-1)
        return ''.join(self.indice_chars[x] for x in vec)


def gen_num():
    """Return a random integer built from one to three distinct digits.

    A leading zero may be sampled, so a three-digit draw can still yield a
    shorter number (``'012'`` -> ``12``).
    """
    n_digits = random.randint(1, 3)
    nums = random.sample('0123456789', n_digits)
    return int(''.join(nums))


def generate_dataset(size, query_len=MAXLEN, answer_len=ANSWER_LEN):
    """Generate ``size`` unique addition problems with their answers.

    ``a+b`` and ``b+a`` count as the same problem.  Both returned lists
    hold strings right-padded with spaces to ``query_len`` and
    ``answer_len`` respectively.

    Raises:
        ValueError: ``size`` exceeds ``MAX_UNIQUE_PAIRS``.
    """
    if size > MAX_UNIQUE_PAIRS:
        raise ValueError(
            'only {} unique problems exist, {} requested'.format(
                MAX_UNIQUE_PAIRS, size))

    questions, expected = [], []
    seen = set()
    while len(questions) < size:
        a, b = gen_num(), gen_num()
        key = tuple(sorted((a, b)))  # order-independent identity of the pair
        if key in seen:
            continue
        seen.add(key)
        questions.append('{}+{}'.format(a, b).ljust(query_len))
        expected.append(str(a + b).ljust(answer_len))
    return questions, expected


def vectorize(table, sentences, length):
    """One-hot encode ``sentences`` into a boolean array.

    The result has shape ``(len(sentences), length, len(table.chars))``.
    """
    # Builtin ``bool`` instead of the ``np.bool`` alias, which newer NumPy
    # releases no longer provide.
    out = np.zeros((len(sentences), length, len(table.chars)), dtype=bool)
    for i, sent in enumerate(sentences):
        out[i] = table.encode(sent, length)
    return out


def build_model(n_chars):
    """Build and compile the LSTM encoder/decoder for ``n_chars`` symbols."""
    # Imported here so the data helpers above remain importable (and
    # testable) on machines without Keras installed.
    from keras.models import Sequential
    from keras.layers import (
        Activation, TimeDistributed, Dense, RepeatVector, recurrent)

    model = Sequential()
    # Encoder: compress the whole question into a single 128-d vector.
    model.add(recurrent.LSTM(128, input_shape=(MAXLEN, n_chars)))
    # Feed that vector to the decoder once per output character.
    model.add(RepeatVector(ANSWER_LEN))
    model.add(recurrent.LSTM(128, return_sequences=True))
    model.add(recurrent.LSTM(128, return_sequences=True))
    # Per-timestep distribution over the alphabet.
    model.add(TimeDistributed(Dense(n_chars)))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam',
                  metrics=['accuracy'])
    return model


ctable = CharacterTable(MAXLEN)


def main():
    """Generate data, train, print sample predictions and save the model."""
    questions, expected = generate_dataset(50000)
    print('total questions', len(questions))

    X = vectorize(ctable, questions, MAXLEN)
    y = vectorize(ctable, expected, ANSWER_LEN)

    model = build_model(len(ctable.chars))
    # NOTE(review): newer Keras releases name this argument ``epochs``;
    # adjust if fit() rejects ``nb_epoch``.
    model.fit(X, y, batch_size=64, nb_epoch=20,
              validation_split=0.02, verbose=2)

    # Spot-check a few random samples.
    for _ in range(10):
        ind = np.random.randint(0, len(questions) - 5)
        x_test, y_test = X[ind:ind + 5], y[ind:ind + 5]
        # predict() + argmax (done inside decode) replaces predict_classes(),
        # which is not available in every Keras version.
        y_probs = model.predict(x_test, verbose=0)
        print('Q', ctable.decode(x_test[0]))
        print('T', ctable.decode(y_test[0]))
        print('Pred', ctable.decode(y_probs[0]))

    json_string = model.to_json()
    # Text mode: to_json() returns str, which a binary-mode file rejects
    # on Python 3.
    with open('rnn_add_model.json', 'w') as fw:
        fw.write(json_string)
    model.save_weights('rnn_add_model.h5')


if __name__ == '__main__':
    main()
# Largely modeled on the official example, slightly simplified; training takes about 1 hour and reaches 99.6% accuracy.