zoukankan      html  css  js  c++  java
  • 基于 LSTM 轻松生成各种古诗

    整个过程分为以下步骤完成:

    1. 语料准备
    2. 语料预处理
    3. 模型参数配置
    4. 构建模型
    5. 训练模型
    6. 模型作诗
    7. 绘制模型网络结构图

    下面一步步来构建和训练一个会写诗的模型。

    第一,语料准备。一共四万多首古诗,每行一首诗,标题在预处理的时候已经去掉了。

    第二,文件预处理。首先,机器并不懂每个中文汉字代表的是什么,所以要将文字转换为机器能理解的形式,这里我们采用 One-Hot 的形式,这样诗句中的每个字都能用向量来表示,下面定义函数 preprocess_file() 来处理。

     1 puncs = [']', '[', '', '', '{', '}', '', '', '']
     2 
     3 
     4 def preprocess_file(Config):
     5     # 语料文本内容
     6     files_content = ''
     7     with open(Config.poetry_file, 'r', encoding='utf-8') as f:
     8         for line in f:
     9             # 每行的末尾加上"]"符号代表一首诗结束
    10             for char in puncs:
    11                 line = line.replace(char, "")
    12             files_content += line.strip() + "]"
    13 
    14     words = sorted(list(files_content))
    15     words.remove(']')
    16     counted_words = {}
    17     for word in words:
    18         if word in counted_words:
    19             counted_words[word] += 1
    20         else:
    21             counted_words[word] = 1
    22 
    23     # 去掉低频的字
    24     erase = []
    25     for key in counted_words:
    26         if counted_words[key] <= 2:
    27             erase.append(key)
    28     for key in erase:
    29         del counted_words[key]
    30     del counted_words[']']
    31     wordPairs = sorted(counted_words.items(), key=lambda x: -x[1])
    32 
    33     words, _ = zip(*wordPairs)
    34     # word到id的映射
    35     word2num = dict((c, i + 1) for i, c in enumerate(words))
    36     num2word = dict((i, c) for i, c in enumerate(words))
    37     word2numF = lambda x: word2num.get(x, 0)
    38     return word2numF, num2word, words, files_content

    在每行末尾加上 ] 符号是为了标识这首诗已经结束了。我们给模型学习的方法是,给定前六个字,生成第七个字,所以在后面生成训练数据的时候,会以6的跨度,1的步长截取文字,生成语料。如果出现了 ] 符号,说明 ] 符号之前的语句和之后的语句是两首诗里面的内容,两首诗之间是没有关联关系的,所以我们后面会舍弃掉包含 ] 符号的训练数据。

    第三,模型参数配置。预先定义模型参数和加载语料以及模型保存名称,通过类 Config 实现。

    1 class Config(object):
    2     poetry_file = 'poetry.txt'
    3     weight_file = 'poetry_model.h5'
    4     # 根据前六个字预测第七个字
    5     max_len = 6
    6     batch_size = 512
    7     learning_rate = 0.001

    第四,构建模型,通过 PoetryModel 类实现,类的代码结构如下:

     1  class PoetryModel(object):
     2         def __init__(self, config):
     3             pass
     4 
     5         def build_model(self):
     6             pass
     7 
     8         def sample(self, preds, temperature=1.0):
     9             pass
    10 
    11         def generate_sample_result(self, epoch, logs):
    12             pass
    13 
    14         def predict(self, text):
    15             pass
    16 
    17         def data_generator(self):
    18             pass
    19         def train(self):
    20             pass

    类中定义的方法具体实现功能如下:

    (1)init 函数定义,通过加载 Config 配置信息,进行语料预处理和模型加载,如果模型文件存在则直接加载模型,否则开始训练。

     1  def __init__(self, config):
     2             self.model = None
     3             self.do_train = True
     4             self.loaded_model = False
     5             self.config = config
     6 
     7             # 文件预处理
     8             self.word2numF, self.num2word, self.words, self.files_content = preprocess_file(self.config)
     9             if os.path.exists(self.config.weight_file):
    10                 self.model = load_model(self.config.weight_file)
    11                 self.model.summary()
    12             else:
    13                 self.train()
    14             self.do_train = False
    15             self.loaded_model = True

    (2)build_model 函数主要用 Keras 来构建网络模型,这里使用 LSTM 的 GRU 来实现,当然直接使用 LSTM 也没问题。

     1 def build_model(self):
     2     '''建立模型'''
     3     input_tensor = Input(shape=(self.config.max_len,))
     4     embedd = Embedding(len(self.num2word) + 1, 300, input_length=self.config.max_len)(input_tensor)
     5     lstm = Bidirectional(GRU(128, return_sequences=True))(embedd)
     6     dropout = Dropout(0.6)(lstm)
     7     lstm = Bidirectional(GRU(128, return_sequences=True))(embedd)
     8     dropout = Dropout(0.6)(lstm)
     9     flatten = Flatten()(lstm)
    10     dense = Dense(len(self.words), activation='softmax')(flatten)
    11     self.model = Model(inputs=input_tensor, outputs=dense)
    12     optimizer = Adam(lr=self.config.learning_rate)
    13     self.model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

    (3)sample 函数,在训练过程的每个 epoch 迭代中采样。

     1    def sample(self, preds, temperature=1.0):
     2             '''
     3             当temperature=1.0时,模型输出正常
     4             当temperature=0.5时,模型输出比较open
     5             当temperature=1.5时,模型输出比较保守
     6             在训练的过程中可以看到temperature不同,结果也不同
     7             '''
     8             preds = np.asarray(preds).astype('float64')
     9             preds = np.log(preds) / temperature
    10             exp_preds = np.exp(preds)
    11             preds = exp_preds / np.sum(exp_preds)
    12             probas = np.random.multinomial(1, preds, 1)
    13             return np.argmax(probas)

    (4)训练过程中,每个 epoch 打印出当前的学习情况。

     1 def generate_sample_result(self, epoch, logs):
     2     print("
    ==================Epoch {}=====================".format(epoch))
     3     for diversity in [0.5, 1.0, 1.5]:
     4         print("------------Diversity {}--------------".format(diversity))
     5         start_index = random.randint(0, len(self.files_content) - self.config.max_len - 1)
     6         generated = ''
     7         sentence = self.files_content[start_index: start_index + self.config.max_len]
     8         generated += sentence
     9         for i in range(20):
    10             x_pred = np.zeros((1, self.config.max_len))
    11             for t, char in enumerate(sentence[-6:]):
    12                 x_pred[0, t] = self.word2numF(char)
    13 
    14             preds = self.model.predict(x_pred, verbose=0)[0]
    15             next_index = self.sample(preds, diversity)
    16             next_char = self.num2word[next_index]
    17             generated += next_char
    18             sentence = sentence + next_char
    19         print(sentence)

    (5)predict 函数,用于根据给定的提示,来进行预测。

    根据给出的文字,生成诗句,如果给的 text 不到四个字,则随机补全。

     1 def predict(self, text):
     2         if not self.loaded_model:
     3             return
     4         with open(self.config.poetry_file, 'r', encoding='utf-8') as f:
     5             file_list = f.readlines()
     6         random_line = random.choice(file_list)
     7         # 如果给的text不到四个字,则随机补全
     8         if not text or len(text) != 4:
     9             for _ in range(4 - len(text)):
    10                 random_str_index = random.randrange(0, len(self.words))
    11                 text += self.num2word.get(random_str_index) if self.num2word.get(random_str_index) not in [',', '',
    12                                                                                                            ''] else self.num2word.get(
    13                     random_str_index + 1)
    14         seed = random_line[-(self.config.max_len):-1]
    15         res = ''
    16         seed = 'c' + seed
    17         for c in text:
    18             seed = seed[1:] + c
    19             for j in range(5):
    20                 x_pred = np.zeros((1, self.config.max_len))
    21                 for t, char in enumerate(seed):
    22                     x_pred[0, t] = self.word2numF(char)
    23                 preds = self.model.predict(x_pred, verbose=0)[0]
    24                 next_index = self.sample(preds, 1.0)
    25                 next_char = self.num2word[next_index]
    26                 seed = seed[1:] + next_char
    27             res += seed
    28         return res

    (6) data_generator 函数,用于生成数据,提供给模型训练时使用。

     1 def data_generator(self):
     2     i = 0
     3     while 1:
     4         x = self.files_content[i: i + self.config.max_len]
     5         y = self.files_content[i + self.config.max_len]
     6         puncs = [']', '[', '', '', '{', '}', '', '', '', ':']
     7         if len([i for i in puncs if i in x]) != 0:
     8             i += 1
     9             continue
    10         if len([i for i in puncs if i in y]) != 0:
    11             i += 1
    12             continue
    13         y_vec = np.zeros(
    14             shape=(1, len(self.words)),
    15             dtype=np.bool
    16         )
    17         y_vec[0, self.word2numF(y)] = 1.0
    18         x_vec = np.zeros(
    19             shape=(1, self.config.max_len),
    20             dtype=np.int32
    21         )
    22         for t, char in enumerate(x):
    23             x_vec[0, t] = self.word2numF(char)
    24         yield x_vec, y_vec
    25         i += 1

    (7)train 函数,用来进行模型训练,其中迭代次数 number_of_epoch ,是根据训练语料长度除以 batch_size 计算的,如果在调试中,想用更小一点的number_of_epoch ,可以自定义大小,把 train 函数的第一行代码注释即可。

     1 def train(self):
     2         #number_of_epoch = len(self.files_content) // self.config.batch_size
     3         number_of_epoch = 10
     4         if not self.model:
     5             self.build_model()
     6         self.model.summary()
     7         self.model.fit_generator(
     8             generator=self.data_generator(),
     9             verbose=True,
    10             steps_per_epoch=self.config.batch_size,
    11             epochs=number_of_epoch,
    12             callbacks=[
    13                 keras.callbacks.ModelCheckpoint(self.config.weight_file, save_weights_only=False),
    14                 LambdaCallback(on_epoch_end=self.generate_sample_result)
    15             ]
    16         )

    第五,整个模型构建好以后,接下来进行模型训练。

     model = PoetryModel(Config)

    训练过程中的第1-2轮迭代:

    enter image description here

    训练过程中的第9-10轮迭代:

    enter image description here

    虽然训练过程写出的诗句不怎么能看得懂,但是可以看到模型从一开始标点符号都不会用 ,到最后写出了有一点点模样的诗句,能看到模型变得越来越聪明了。

    第六,模型作诗,模型迭代10次之后的测试,首先输入几个字,模型根据输入的提示,做出诗句。

        text = input("text:")
        sentence = model.predict(text)
        print(sentence)

    比如输入:小雨,模型做出的诗句为:

    输入:text:小雨

    结果:小妃侯里守。雨封即客寥。俘剪舟过槽。傲老槟冬绛。

    第七,绘制网络结构图。

    模型结构绘图,采用 Keras自带的功能实现:

        plot_model(model.model, to_file='model.png')

    得到的模型结构图如下:

    enter image description here

    本节使用 LSTM 的变形 GRU 训练出一个能作诗的模型,当然大家可以替换训练语料为歌词或者小说,让机器人自动创作不同风格的歌曲或者小说。

    参考文献以及推荐阅读:

    1. 基于 Keras 和 LSTM 的文本生成
  • 相关阅读:
    BLE 5协议栈-安全管理层
    BLE 5协议栈-通用属性规范层(GATT)
    BLE 5协议栈-属性协议层(ATT)
    BLE 5协议栈-逻辑链路控制与适配协议层(L2CAP)
    BLE 5协议栈-主机控制接口(HCI)
    BLE 5协议栈-直接测试模式
    BLE 5协议栈-链路层
    BLE 5协议栈-物理层
    名词缩写
    C#中数据库事务、存储过程基本用法
  • 原文地址:https://www.cnblogs.com/chen8023miss/p/11977277.html
Copyright © 2011-2022 走看看