zoukankan      html  css  js  c++  java
  • rnn实现三位数加法的训练

    #!/usr/bin/env python
    # coding=utf-8
    
    from keras.models import Sequential
    from keras.layers import Activation, TimeDistributed, Dense, RepeatVector, recurrent
    import numpy as np
    import string
    import random
    
    class CharacterTable(object):
        """One-hot encoder/decoder for the addition vocabulary.

        The vocabulary is the ten decimal digits, '+' and the space character
        used as padding, so every encoded row has ``len(self.chars)`` columns.
        """

        def __init__(self, maxlen):
            """Build the character <-> index lookup tables.

            Args:
                maxlen: default number of rows (characters) produced by ``encode``.
            """
            self.chars = string.digits + '+ '
            self.char_indices = {c: i for i, c in enumerate(self.chars)}
            self.indice_chars = {i: c for i, c in enumerate(self.chars)}
            self.maxlen = maxlen

        def encode(self, strs, maxlen=None):
            """One-hot encode ``strs`` into a ``(maxlen, len(self.chars))`` array.

            Args:
                strs: string built only from vocabulary characters, at most
                    ``maxlen`` characters long.
                maxlen: number of rows; ``None`` means ``self.maxlen``. An
                    explicit ``0`` is honoured instead of being replaced.

            Returns:
                A float ndarray; rows beyond ``len(strs)`` remain all-zero.

            Raises:
                KeyError: if ``strs`` contains a character outside the vocabulary.
                IndexError: if ``strs`` is longer than ``maxlen``.
            """
            if maxlen is None:
                maxlen = self.maxlen
            vec = np.zeros((maxlen, len(self.chars)))
            for i, c in enumerate(strs):
                vec[i, self.char_indices[c]] = 1
            return vec

        def decode(self, vec, calc_argmax=True):
            """Turn an encoded sequence back into a string.

            Args:
                vec: 2-D one-hot/probability array when ``calc_argmax`` is true,
                    otherwise a 1-D sequence of vocabulary indices.
                calc_argmax: take the argmax over the last axis first. Note that
                    an all-zero row decodes to index 0, i.e. the character '0'.

            Returns:
                The decoded string, one character per row/index.
            """
            if calc_argmax:
                vec = vec.argmax(axis=-1)
            return ''.join(self.indice_chars[x] for x in vec)
    
    def gen_num(digits=3):
        """Return a random non-negative integer written with 1..``digits`` digits.

        Each digit is drawn independently (with replacement), so numbers with
        repeated digits such as 11, 100 or 999 can occur. The previous
        ``random.sample`` version drew digits without replacement and could
        never produce them. A leading zero (e.g. "07") simply yields a shorter
        number.

        Args:
            digits: maximum number of decimal digits (default 3).
        """
        length = random.randint(1, digits)
        return int(''.join(random.choice(string.digits) for _ in range(length)))
    
    MAXLEN = 7  # longest query: 3 digits + '+' + 3 digits
    ANSWER_LEN = 4  # longest answer: 999 + 999 = 1998
    TRAINING_SIZE = 50000
    ctable = CharacterTable(MAXLEN)

    # Generate TRAINING_SIZE distinct "a+b" problems, right-padded with spaces
    # so every query is MAXLEN characters and every answer ANSWER_LEN.
    questions, expected = [], []
    seen = set()
    while len(questions) < TRAINING_SIZE:
        a, b = gen_num(), gen_num()
        # a+b and b+a are the same problem: key on the unordered pair.
        key = tuple(sorted((a, b)))
        if key in seen:
            continue
        seen.add(key)
        questions.append('{}+{}'.format(a, b).ljust(MAXLEN))
        expected.append(str(a + b).ljust(ANSWER_LEN))
    print('total questions', len(questions))
    
    # One-hot encode every query/answer. The builtin ``bool`` dtype is used
    # because the ``np.bool`` alias was deprecated in NumPy 1.20 and removed
    # in NumPy 1.24.
    X = np.zeros((len(questions), MAXLEN, len(ctable.chars)), dtype=bool)
    y = np.zeros((len(expected), 4, len(ctable.chars)), dtype=bool)

    for i, sent in enumerate(questions):
        X[i] = ctable.encode(sent)

    for i, sent in enumerate(expected):
        # Answers are 4 characters wide (max sum 1998), not MAXLEN.
        y[i] = ctable.encode(sent, 4)
    
    # Encoder-decoder: the first LSTM reads the padded query and returns only
    # its final 128-d state; RepeatVector copies that state to each of the 4
    # answer positions; the decoder LSTMs emit one hidden state per position.
    model = Sequential()
    model.add(recurrent.LSTM(128, input_shape=(MAXLEN, len(ctable.chars))))
    model.add(RepeatVector(4))
    model.add(recurrent.LSTM(128, return_sequences=True))
    model.add(recurrent.LSTM(128, return_sequences=True))

    # Per-position softmax over the character vocabulary.
    model.add(TimeDistributed(Dense(len(ctable.chars))))
    model.add(Activation('softmax'))

    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])

    # Keras 2 renamed ``nb_epoch`` to ``epochs``; current Keras / tf.keras
    # reject the old keyword. ``validation_split`` holds out the last 2% of
    # the samples for validation.
    model.fit(X, y, batch_size=64, epochs=20, validation_split=0.02, verbose=2)
    
    # Spot-check: print query, target and prediction for 10 random samples
    # (drawn from the data the model was fitted on).
    for _ in range(10):
        ind = np.random.randint(0, len(X))
        x_test, y_test = X[ind:ind + 1], y[ind:ind + 1]
        # ``Sequential.predict_classes`` was removed in TensorFlow 2.6; the
        # argmax of the softmax output over the character axis is equivalent
        # and works on every Keras version.
        y_preds = model.predict(x_test, verbose=0).argmax(axis=-1)
        print('Q', ctable.decode(x_test[0]))
        print('T', ctable.decode(y_test[0]))
        print('Pred', ctable.decode(y_preds[0], calc_argmax=False))
    
    
    # Persist the architecture (JSON text) and the weights separately.
    # ``to_json`` returns ``str``, so the file must be opened in text mode;
    # writing a str to a 'wb' handle raises TypeError on Python 3.
    json_string = model.to_json()
    with open('rnn_add_model.json', 'w') as fw:
        fw.write(json_string)
    model.save_weights('rnn_add_model.h5')

    基本是模仿 Keras 官网的加法示例，精简了一点；训练约 1 小时，准确率 99.6%（按字符计算的准确率）。

    每天一小步，人生一大步！Good luck~
  • 相关阅读:
    利用JBoss漏洞拿webshell方法
    jboss漏洞导致服务器中毒
    dubbo bug之 Please check registry access list (whitelist/blacklist)的分析与解决
    将list转为json字符串
    MySQL语句给字段值加1
    java int怎么转换为string
    HttpURLConnection如何添加请求头?
    eclipse下载egit插件,实现代码git同步问题
    eclipse编译项目用maven编译问题
    fastjson将java list转为json字符串
  • 原文地址:https://www.cnblogs.com/jkmiao/p/6337862.html
Copyright © 2011-2022 走看看