  • Keras Hands-On Tutorial 3 (Text Generation)

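    These are just notes for my own record.

    If you want to study text generation, GPT2-Chinese is worth a look: https://github.com/Morizeyao/GPT2-Chinese

    Data: https://github.com/wb14123/couplet-dataset/releases/download/1.0/couplet.tar.gz

    Model overview: I do not explain how the model works here (I may write a separate post when I have time); please look up the background yourself.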

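    Reading the data

    # Read the input (in.txt) and target (out.txt) files as UTF-8 text.
    with open("./couplet/train/in.txt", "r", encoding="utf-8") as f:
        data_in = f.read()
    with open("./couplet/train/out.txt", "r", encoding="utf-8") as f:
        data_out = f.read()
    # One sentence per line; tokens within a line are separated by whitespace.
    data_in_list = data_in.split("\n")
    data_out_list = data_out.split("\n")
    data_in_list = [data.split() for data in data_in_list]
    data_out_list = [data.split() for data in data_out_list]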
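
The parsing step can be tried on an in-memory string before touching the real files. The sketch below is only an illustration: `sample_text` is a made-up stand-in, and it assumes the files hold whitespace-separated characters, one sentence per line, as the `split()` calls above imply. It shows that a trailing newline leaves one empty token list at the end.

```python
# Hypothetical stand-in for the contents of in.txt (not real dataset lines).
sample_text = "晚 风 摇 树 \n愿 景 天 成 \n"

# Splitting on "\n" yields an empty string after the final newline.
lines = sample_text.split("\n")
tokens = [line.split() for line in lines]

# Optional clean-up: drop empty token lists. If applied to real data, the same
# positions must be dropped from both files so inputs and targets stay aligned.
tokens_clean = [t for t in tokens if t]

print(len(tokens), len(tokens_clean))
```

An empty token list later becomes an all-zero padded row, so leaving it in is mostly harmless, but removing it keeps the training set free of blank samples.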

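    Building the vocabulary

    import itertools
    # Collect every distinct token from both the input and the target sentences.
    words_all = list(itertools.chain.from_iterable(data_in_list))+list(itertools.chain.from_iterable(data_out_list))
    words_all = set(words_all)
    # Ids start at 1; id 0 is reserved for unknown tokens (and is also the padding value).
    vocab = {j:i+1 for i ,j in enumerate(words_all)}
    vocab["unk"] = 0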
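
One thing to watch with a vocabulary built from a `set`: the iteration order of a set of strings can change from one Python process to the next, so the ids assigned in one run may not match those of another. The sketch below is a variation rather than the original method: it sorts the tokens first so the mapping is reproducible. The toy sentences are invented for illustration.

```python
import itertools

# Toy sentences standing in for data_in_list / data_out_list (hypothetical data).
sentences_in = [["春", "风"], ["明", "月"]]
sentences_out = [["秋", "雨"], ["清", "风"]]

chars = set(itertools.chain.from_iterable(sentences_in + sentences_out))

# sorted() fixes the order, so every run assigns the same id to the same token.
vocab = {ch: i + 1 for i, ch in enumerate(sorted(chars))}
vocab["unk"] = 0  # id 0 doubles as the unknown token and the padding value

# Reverse mapping, needed later to turn predicted ids back into characters.
id2char = {i: ch for ch, i in vocab.items()}

print(len(vocab), vocab.get("龙", 0))
```

If the trained model is to be reused in a later session, the vocabulary should be saved alongside it, since the model's output indices only make sense with the exact mapping used during training.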

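    Data processing

    from keras.preprocessing.sequence import pad_sequences
    # Map each token to its id; tokens missing from the vocabulary become 0.
    data_in2id = [[vocab.get(word,0) for word in sen] for sen in data_in_list]
    data_out2id = [[vocab.get(word,0) for word in sen] for sen in data_out_list]
    # Pad (or truncate) every sequence to length 100; by default zeros go at the front.
    train_data = pad_sequences(data_in2id, maxlen=100)
    train_label = pad_sequences(data_out2id, maxlen=100)
    # sparse_categorical_crossentropy expects labels of shape (samples, timesteps, 1).
    train_label_input = train_label.reshape(*train_label.shape, 1)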
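
With its default settings, `pad_sequences` pads with zeros at the front of each sequence and, when a sequence is too long, keeps its last `maxlen` items. The helper below, `pad_pre`, is a hypothetical pure-Python imitation of that default behavior, written only to make the resulting layout visible. It is not the Keras implementation and it assumes `maxlen >= 1`.

```python
def pad_pre(seqs, maxlen, value=0):
    """Imitate the default 'pre' padding and 'pre' truncation on lists of ids."""
    padded = []
    for seq in seqs:
        seq = list(seq)[-maxlen:]  # too long: keep only the last maxlen ids
        padded.append([value] * (maxlen - len(seq)) + seq)  # too short: zeros in front
    return padded

print(pad_pre([[1, 2, 3]], 5))           # short sequence gets zeros in front
print(pad_pre([[1, 2, 3, 4, 5, 6]], 4))  # long sequence loses its first items
```

Because the real tokens end up at the end of each row, the prediction step has to read the last `len(input_sen)` positions of the output, not the first ones.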

    Building the model

    from keras.models import Sequential
    from keras.layers import GRU, Dense, TimeDistributed, RepeatVector, Bidirectional
    from keras.layers import Embedding
    from keras.optimizers import Adam
    from keras.losses import sparse_categorical_crossentropy
    
    def seq2seq_model(input_length,output_sequence_length,vocab_size):
        model = Sequential()
        # Encoder: embed the ids, then compress the whole sequence into one vector.
        model.add(Embedding(input_dim=vocab_size,output_dim = 128,input_length=input_length))
        model.add(Bidirectional(GRU(128, return_sequences = False)))
        model.add(Dense(128, activation="relu"))
        # Decoder: repeat that vector once per output step and unroll a second GRU.
        model.add(RepeatVector(output_sequence_length))
        model.add(Bidirectional(GRU(128, return_sequences = True)))
        # Predict a distribution over the vocabulary at every output position.
        model.add(TimeDistributed(Dense(vocab_size, activation = 'softmax')))
        model.compile(loss = sparse_categorical_crossentropy, 
                      optimizer = Adam(1e-3))
        model.summary()
        return model
    # len(vocab) already counts the "unk" entry at id 0.
    model = seq2seq_model(train_data.shape[1],train_label.shape[1],len(vocab))
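
To see how data flows through this stack, it can help to trace the output shape of each layer by hand. The sketch below does that with a hypothetical helper, `trace_shapes`. It does not run Keras. It assumes the default `merge_mode` of `Bidirectional`, which concatenates the forward and backward outputs, and the vocabulary size of 9000 is a placeholder, not a measured value.

```python
def trace_shapes(batch, seq_len, vocab_size, embed_dim=128, units=128):
    """Return (layer, output shape) pairs for the encoder-decoder stack above.

    The article uses 128 for the embedding, both GRUs and the Dense layer,
    so a single `units` parameter covers the GRU and Dense widths here.
    """
    return [
        ("Embedding", (batch, seq_len, embed_dim)),
        ("Bidirectional(GRU), last state only", (batch, 2 * units)),
        ("Dense", (batch, units)),
        ("RepeatVector", (batch, seq_len, units)),
        ("Bidirectional(GRU), full sequence", (batch, seq_len, 2 * units)),
        ("TimeDistributed(Dense softmax)", (batch, seq_len, vocab_size)),
    ]

shapes = trace_shapes(batch=32, seq_len=100, vocab_size=9000)
for name, shape in shapes:
    print(f"{name:40s} {shape}")
```

The trace makes the main limitation of this design visible: the entire input sentence is squeezed into one fixed-size vector before decoding, and every output step receives that same vector.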

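    Training and prediction

    model.fit(train_data,train_label_input, batch_size =32, epochs =1, validation_split = 0.2) 
    import numpy as np
    input_sen ="国破山河在"
    # Encode the input sentence character by character and pad it like the training data.
    char2id = [vocab.get(i,0) for i in input_sen]
    input_data = pad_sequences([char2id], maxlen=100)
    # Padding sits at the front, so the answer is in the last len(input_sen) positions.
    result = model.predict(input_data)[0][-len(input_sen):]
    # Pick the most probable id at each position and map it back to a character.
    result_label = [np.argmax(i) for i in result]
    dict_res = {i:j for j,i in vocab.items()}
    print([dict_res.get(i) for i in  result_label])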
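
The decoding step can be checked in isolation on a tiny hand-made probability table. In the sketch below, `id2char` and `probs` are invented values, not model output, and a small pure-Python `argmax` stands in for `np.argmax` so the example has no dependencies.

```python
# Hypothetical reverse vocabulary and per-position probabilities (3 steps, 4 ids).
id2char = {0: "unk", 1: "山", 2: "水", 3: "花"}
probs = [
    [0.90, 0.05, 0.03, 0.02],  # a padding position at the front
    [0.10, 0.70, 0.10, 0.10],
    [0.20, 0.10, 0.60, 0.10],
]
input_len = 2  # length of the unpadded input sentence

def argmax(row):
    """Index of the largest value in a list."""
    return max(range(len(row)), key=row.__getitem__)

# Keep only the trailing positions, which line up with the real input tokens.
last_steps = probs[-input_len:]
ids = [argmax(row) for row in last_steps]
text = "".join(id2char.get(i, "unk") for i in ids)
print(ids, text)
```

This is greedy decoding: each position takes its single most probable character independently of the others.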

    Reference: https://www.jianshu.com/p/d7a9dcd14456

  • Original article: https://www.cnblogs.com/pergrand/p/12967128.html