zoukankan      html  css  js  c++  java
  • RNN预测字母

    #字母预测:输入a预测出b,输入b预测出c,输入c预测出d,输入d预测出e,输入e预测出a
    #10000  a
    #01000  b
    #00100  c
    #00010  d
    #00001  e
    
    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.layers import Dense,SimpleRNN
    import matplotlib.pyplot as plt
    import os
    
    input_word='abcde'
    w_to_id={'a':0,'b':1,'c':2,'d':3,'e':4,}
    id_to_onehot={0:[1.,0.,0.,0.,0.],1:[0.,1.,0.,0.,0.],2:[0.,0.,1.,0.,0.],3:[0.,0.,0.,1.,0.],4:[0.,0.,0.,0.,1.]}
    x_train=[id_to_onehot[w_to_id['a']],id_to_onehot[w_to_id['b']],id_to_onehot[w_to_id['c']],id_to_onehot[w_to_id['d']],id_to_onehot[w_to_id['e']]]
    y_train=[w_to_id['b'],w_to_id['c'],w_to_id['d'],w_to_id['e'],w_to_id['a']]
    
    np.random.seed(7)
    np.random.shuffle(x_train)
    np.random.seed(7)
    np.random.shuffle(y_train)
    tf.random.set_seed(7)
    
    #使x_train符合SimpleRNN输入要求:[送入样本数,循环核时间展开步数,每个时间步输入特征个数]
    #此处整个数据集送入,所以送入研样本数为len(x_train);输入1个字母出结果,循环核时间展开步数为1;表示独热码有5个输入特征,每个时间步输入特征个数为5
    x_train=np.reshape(x_train,(len(x_train),1,5))
    y_train=np.array(y_train)
    
    model=tf.keras.Sequential([SimpleRNN(3),Dense(5,activation='softmax')])
    
    model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=['sparse_categorical_accuracy'])
    
    checkpoint_save_path='./checkpoint/rnn_onehot_lprel.ckpt'
    
    if os.path.exists(checkpoint_save_path+'.index'):
        print('-------------------load the model--------------')
        model.load_weights(checkpoint_save_path)
    
    cp_callback=tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,
                                                   save_best_only=True,
                                                   monitor='loss')  #由于fit没有给出测试集,不计算测试集准确率,根据loss,保存最优模型
    
    history=model.fit(x_train,y_train,batch_size=32,epochs=50,callbacks=[cp_callback])
    
    model.summary()
    # print(model.trainable_variables)
    file = open('./rnn_weights.txt', 'w')
    for v in model.trainable_variables:
        file.write(str(v.name) + '
    ')
        file.write(str(v.shape) + '
    ')
        file.write(str(v.numpy()) + '
    ')
    file.close()
    
    ###############################################    show   ###############################################
    
    # 显示训练集和验证集的acc和loss曲线
    acc = history.history['sparse_categorical_accuracy']
    loss = history.history['loss']
    
    
    plt.subplot(1, 2, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.title('Training Accuracy')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(loss, label='Training Loss')
    plt.title('Training Loss')
    plt.legend()
    plt.show()
    
    preNum=int(input('input the number of test alphabet'))
    
    for i in range(preNum):
        alphabet1=input('input test alphabet:')
        alphabet=[id_to_onehot[w_to_id[alphabet1]]]
        #使alphabet符合SimpleRNN输入要求[送入样本数,循环核时间展开步数,每个时间步输入特征个数]
        #使此处验证效果送入了1个样本,送入样本数为1;输入1个字母出结果,所以循环核时间展开步数为1;独热码有5个输入特征,每个时间步输入特征个数为5
        alphabet=np.reshape(alphabet,(1,1,5))
        reseult=model.predict([alphabet])
        pred=tf.argmax(reseult,axis=1)
        pred=int(pred)
        tf.print(alphabet1 + '->' + input_word[pred])
  • 相关阅读:
    【树】树的前序遍历(非递归)
    表单提交中的input、button、submit的区别
    利用setTimeout来实现setInterval
    Jquery动画操作的stop()函数
    Javascript实现简单的双向绑定
    Javascript观察者模式
    CSS reset
    【CSS3】background-origin和background-clip的区别
    :before和::before的区别
    JS实现瀑布流
  • 原文地址:https://www.cnblogs.com/python2/p/13610445.html
Copyright © 2011-2022 走看看