【Text Classification - Chinese】textRNN

    1. Overview

    Building on the earlier English text classification example, this post looks at Chinese text classification, treated as a 10-class problem (sports, technology, games, finance, real estate, home furnishing, and so on).

    2. Dataset

    The dataset consists of news articles. There are four data files under the /data/cnews directory: the training set, test set, validation set, and a vocabulary file (the vocabulary file cnews.vocab.txt is optional, since training can regenerate it automatically). Data format: each line starts with the category, followed by the content.

    Training data download: https://pan.baidu.com/s/1ZHh98RrjQpG5Tm-yq73vBQ (extraction code: 2r04)

    Format of the training set: one sample per line, with the category and the content separated by a tab.

    Format of cnews.vocab.txt: one character per line, with <PAD> as the first entry.
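
    For illustration, here is how one record splits into a label and content (a minimal sketch; the sentence is a made-up sample, not a real line from the dataset):

        # One record: category and content separated by a tab (hypothetical sample).
        line = '体育\t某球队在昨晚的比赛中以3比1取胜'
        label, content = line.strip().split('\t')
        print(label, content[:6])   # -> 体育 某球队在昨晚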

    3. Code

    3.1 Data loading: cnews_loader.py

        # coding: utf-8
        import sys
        from collections import Counter
        import numpy as np
        import tensorflow.contrib.keras as kr

        if sys.version_info[0] > 2:
            is_py3 = True
        else:
            reload(sys)
            sys.setdefaultencoding("utf-8")
            is_py3 = False

        def native_word(word, encoding='utf-8'):
            """If a model trained under Python 3 is used under Python 2, call this to convert the encoding."""
            if not is_py3:
                return word.encode(encoding)
            else:
                return word

        def native_content(content):
            if not is_py3:
                return content.decode('utf-8')
            else:
                return content

        def open_file(filename, mode='r'):
            """
            Common file helper that works under both Python 2 and Python 3.
            mode: 'r' or 'w' for read or write
            """
            if is_py3:
                return open(filename, mode, encoding='utf-8', errors='ignore')
            else:
                return open(filename, mode)

        def read_file(filename):
            """Read a data file."""
            contents, labels = [], []
            with open_file(filename) as f:
                for line in f:
                    try:
                        label, content = line.strip().split('\t')
                        if content:
                            contents.append(list(native_content(content)))
                            labels.append(native_content(label))
                    except:
                        pass
            return contents, labels

        def build_vocab(train_dir, vocab_dir, vocab_size=5000):
            """Build the vocabulary from the training set and save it."""
            data_train, _ = read_file(train_dir)

            all_data = []
            for content in data_train:
                all_data.extend(content)
            counter = Counter(all_data)
            count_pairs = counter.most_common(vocab_size - 1)
            words, _ = list(zip(*count_pairs))
            # Prepend <PAD> so that all texts can be padded to the same length
            words = ['<PAD>'] + list(words)
            open_file(vocab_dir, mode='w').write('\n'.join(words) + '\n')

        def read_vocab(vocab_dir):
            """Read the vocabulary."""
            # words = open_file(vocab_dir).read().strip().split('\n')
            with open_file(vocab_dir) as fp:
                # Under Python 2, convert every value to unicode
                words = [native_content(_.strip()) for _ in fp.readlines()]
            word_to_id = dict(zip(words, range(len(words))))
            return words, word_to_id

        def read_category():
            """Return the fixed list of categories."""
            categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']
            categories = [native_content(x) for x in categories]
            cat_to_id = dict(zip(categories, range(len(categories))))
            return categories, cat_to_id

        def to_words(content, words):
            """Convert id-encoded content back to text."""
            return ''.join(words[x] for x in content)

        def process_file(filename, word_to_id, cat_to_id, max_length=600):
            """Convert a file to id representation."""
            contents, labels = read_file(filename)

            data_id, label_id = [], []
            for i in range(len(contents)):
                data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])
                label_id.append(cat_to_id[labels[i]])
            # Use the keras pad_sequences helper to pad every text to a fixed length
            x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)
            y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id))  # one-hot labels

            return x_pad, y_pad

        def batch_iter(x, y, batch_size=64):
            """Generate shuffled batches of data."""
            data_len = len(x)
            num_batch = int((data_len - 1) / batch_size) + 1

            indices = np.random.permutation(np.arange(data_len))
            x_shuffle = x[indices]
            y_shuffle = y[indices]

            for i in range(num_batch):
                start_id = i * batch_size
                end_id = min((i + 1) * batch_size, data_len)
                yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]
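
    A minimal usage sketch for the loader (the file paths are assumptions, following the run script in 3.3):

        # Usage sketch for cnews_loader; paths assume the layout from Section 3.3.
        from cnews_loader import build_vocab, read_vocab, read_category, process_file, batch_iter

        train_dir = '../data/cnews/cnews.train.txt'
        vocab_dir = '../data/cnews/cnews.vocab.txt'

        build_vocab(train_dir, vocab_dir, vocab_size=5000)   # writes cnews.vocab.txt
        categories, cat_to_id = read_category()
        words, word_to_id = read_vocab(vocab_dir)

        x, y = process_file(train_dir, word_to_id, cat_to_id, max_length=600)
        print(x.shape, y.shape)   # (num_samples, 600) (num_samples, 10)

        for x_batch, y_batch in batch_iter(x, y, batch_size=64):
            pass   # each iteration yields one shuffled mini-batch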

    3.2 Model definition: rnn_model.py

        #!/usr/bin/python
        # -*- coding: utf-8 -*-
        import tensorflow as tf

        class TRNNConfig(object):
            """RNN configuration parameters."""
            # Model parameters
            embedding_dim = 64       # embedding dimension
            seq_length = 600         # sequence length
            num_classes = 10         # number of classes
            vocab_size = 5000        # vocabulary size
            num_layers = 2           # number of hidden layers
            hidden_dim = 128         # hidden units per layer
            rnn = 'gru'              # 'lstm' or 'gru'
            dropout_keep_prob = 0.8  # dropout keep probability
            learning_rate = 1e-3     # learning rate
            batch_size = 128         # training batch size
            num_epochs = 10          # total number of epochs
            print_per_batch = 100    # report results every this many batches
            save_per_batch = 10      # write to tensorboard every this many batches

        class TextRNN(object):
            """RNN model for text classification."""
            def __init__(self, config):
                self.config = config
                # The three input placeholders
                self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
                self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
                self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
                self.rnn()

            def rnn(self):
                """Build the RNN model."""
                def lstm_cell():   # LSTM cell
                    return tf.contrib.rnn.BasicLSTMCell(self.config.hidden_dim, state_is_tuple=True)
                def gru_cell():    # GRU cell
                    return tf.contrib.rnn.GRUCell(self.config.hidden_dim)
                def dropout():     # wrap each RNN cell with a dropout layer
                    if self.config.rnn == 'lstm':
                        cell = lstm_cell()
                    else:
                        cell = gru_cell()
                    return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob)

                # Embedding lookup
                with tf.device('/cpu:0'):
                    embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
                    embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)

                with tf.name_scope("rnn"):
                    # Multi-layer RNN
                    cells = [dropout() for _ in range(self.config.num_layers)]
                    rnn_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)

                    _outputs, _ = tf.nn.dynamic_rnn(cell=rnn_cell, inputs=embedding_inputs, dtype=tf.float32)
                    last = _outputs[:, -1, :]  # take the output at the last time step

                with tf.name_scope("score"):
                    # Fully connected layer, followed by dropout and ReLU
                    fc = tf.layers.dense(last, self.config.hidden_dim, name='fc1')
                    fc = tf.contrib.layers.dropout(fc, self.keep_prob)
                    fc = tf.nn.relu(fc)
                    # Classifier
                    self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
                    self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)  # predicted class

                with tf.name_scope("optimize"):
                    # Cross-entropy loss
                    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
                    self.loss = tf.reduce_mean(cross_entropy)
                    # Optimizer
                    self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

                with tf.name_scope("accuracy"):
                    # Accuracy
                    correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
                    self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
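
    As a quick sanity check, the graph can be built and its key tensors inspected (a minimal sketch; it assumes TensorFlow 1.x, since tf.contrib was removed in TensorFlow 2.x):

        # Sanity-check sketch: build the graph and inspect the key tensors
        # (assumes TensorFlow 1.x; tf.contrib is gone in TensorFlow 2.x).
        from rnn_model import TRNNConfig, TextRNN

        config = TRNNConfig()
        model = TextRNN(config)
        print(model.input_x.shape)    # (?, 600)  padded character ids
        print(model.logits.shape)     # (?, 10)   one score per category
        print(model.y_pred_cls)       # argmax over softmax(logits)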

    3.3 Training and evaluation: run_rnn.py

        # coding: utf-8
        from __future__ import print_function
        import os
        import sys
        import time
        from datetime import timedelta
        import numpy as np
        import tensorflow as tf
        from sklearn import metrics
        from rnn_model import TRNNConfig, TextRNN
        from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab

        base_dir = '../data/cnews'
        train_dir = os.path.join(base_dir, 'cnews.train.txt')
        test_dir = os.path.join(base_dir, 'cnews.test.txt')
        val_dir = os.path.join(base_dir, 'cnews.val.txt')
        vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
        save_dir = '../checkpoints/textrnn'
        save_path = os.path.join(save_dir, 'best_validation')  # path of the best validation checkpoint

        def get_time_dif(start_time):
            """Return the elapsed time."""
            end_time = time.time()
            time_dif = end_time - start_time
            return timedelta(seconds=int(round(time_dif)))

        def feed_data(x_batch, y_batch, keep_prob):
            feed_dict = {
                model.input_x: x_batch,
                model.input_y: y_batch,
                model.keep_prob: keep_prob
            }
            return feed_dict

        def evaluate(sess, x_, y_):
            """Evaluate loss and accuracy on a dataset."""
            data_len = len(x_)
            batch_eval = batch_iter(x_, y_, 128)
            total_loss = 0.0
            total_acc = 0.0
            for x_batch, y_batch in batch_eval:
                batch_len = len(x_batch)
                feed_dict = feed_data(x_batch, y_batch, 1.0)
                y_pred_class, loss, acc = sess.run([model.y_pred_cls, model.loss, model.acc], feed_dict=feed_dict)
                total_loss += loss * batch_len
                total_acc += acc * batch_len
            return y_pred_class, total_loss / data_len, total_acc / data_len

        def train():
            print("Configuring TensorBoard and Saver...")
            # Configure TensorBoard; delete the tensorboard folder before retraining, otherwise the graphs overlap
            tensorboard_dir = '../tensorboard/textrnn'
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            tf.summary.scalar("loss", model.loss)
            tf.summary.scalar("accuracy", model.acc)
            merged_summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(tensorboard_dir)
            # Configure the Saver
            saver = tf.train.Saver()
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            print("Loading training and validation data...")
            # Load the training and validation sets
            start_time = time.time()
            x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)
            x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)
            time_dif = get_time_dif(start_time)
            print("Time usage:", time_dif)

            # Create the session
            session = tf.Session()
            session.run(tf.global_variables_initializer())
            writer.add_graph(session.graph)
            print('Training and evaluating...')
            start_time = time.time()
            total_batch = 0              # total number of batches processed
            best_acc_val = 0.0           # best validation accuracy so far
            last_improved = 0            # batch index of the last improvement
            require_improvement = 1000   # stop early after this many batches without improvement

            flag = False
            for epoch in range(config.num_epochs):
                print('Epoch:', epoch + 1)
                batch_train = batch_iter(x_train, y_train, config.batch_size)
                for x_batch, y_batch in batch_train:
                    feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)

                    if total_batch % config.save_per_batch == 0:
                        # Write training results to the tensorboard scalars
                        s = session.run(merged_summary, feed_dict=feed_dict)
                        writer.add_summary(s, total_batch)

                    if total_batch % config.print_per_batch == 0:
                        # Report performance on the training and validation sets
                        feed_dict[model.keep_prob] = 1.0
                        loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict)
                        y_pred_class, loss_val, acc_val = evaluate(session, x_val, y_val)  # todo

                        if acc_val > best_acc_val:
                            # Save the best result
                            best_acc_val = acc_val
                            last_improved = total_batch
                            saver.save(sess=session, save_path=save_path)
                            improved_str = '*'
                        else:
                            improved_str = ''

                        time_dif = get_time_dif(start_time)
                        msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \
                              ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'
                        print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))

                    session.run(model.optim, feed_dict=feed_dict)  # run one optimization step
                    total_batch += 1

                    if total_batch - last_improved > require_improvement:
                        # Validation accuracy has not improved for a long time: stop early
                        print("No optimization for a long time, auto-stopping...")
                        flag = True
                        break  # exit the inner loop
                if flag:  # same as above
                    break

        def test():
            print("Loading test data...")
            start_time = time.time()
            x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)
            session = tf.Session()
            session.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess=session, save_path=save_path)  # restore the saved model

            print('Testing...')
            y_pred, loss_test, acc_test = evaluate(session, x_test, y_test)
            msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
            print(msg.format(loss_test, acc_test))

            batch_size = 128
            data_len = len(x_test)
            num_batch = int((data_len - 1) / batch_size) + 1

            y_test_cls = np.argmax(y_test, 1)
            y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32)  # holds the predictions
            for i in range(num_batch):  # process batch by batch
                start_id = i * batch_size
                end_id = min((i + 1) * batch_size, data_len)
                feed_dict = {
                    model.input_x: x_test[start_id:end_id],
                    model.keep_prob: 1.0
                }
                y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)

            # Evaluation
            print("Precision, Recall and F1-Score...")
            print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))
            # Confusion matrix
            print("Confusion Matrix...")
            cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
            print(cm)
            time_dif = get_time_dif(start_time)
            print("Time usage:", time_dif)

        if __name__ == '__main__':
            print('Configuring RNN model...')
            config = TRNNConfig()
            if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it does not exist
                build_vocab(train_dir, vocab_dir, config.vocab_size)
            categories, cat_to_id = read_category()
            words, word_to_id = read_vocab(vocab_dir)
            config.vocab_size = len(words)
            model = TextRNN(config)
            option = 'train'
            if option == 'train':
                train()
            else:
                test()
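
    The script hard-codes option='train', so running test() means editing that line. A command-line switch avoids this; here is a sketch (the sys.argv handling is an assumption, not part of the original script):

        # Sketch: pick train or test from the command line instead of editing
        # the hard-coded `option` variable. This would replace the last four
        # lines of the __main__ block above.
        if len(sys.argv) != 2 or sys.argv[1] not in ['train', 'test']:
            raise ValueError("usage: python run_rnn.py [train / test]")
        if sys.argv[1] == 'train':
            train()
        else:
            test()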

    3.4 Prediction: predict.py

        # coding: utf-8
        from __future__ import print_function
        import os
        import tensorflow as tf
        import tensorflow.contrib.keras as kr
        from rnn_model import TRNNConfig, TextRNN
        from cnews_loader import read_category, read_vocab

        try:
            bool(type(unicode))
        except NameError:
            unicode = str

        base_dir = '../data/cnews'
        vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
        save_dir = '../checkpoints/textrnn'
        save_path = os.path.join(save_dir, 'best_validation')  # path of the best validation checkpoint

        class RnnModel:
            def __init__(self):
                self.config = TRNNConfig()
                self.categories, self.cat_to_id = read_category()
                self.words, self.word_to_id = read_vocab(vocab_dir)
                self.config.vocab_size = len(self.words)
                self.model = TextRNN(self.config)
                self.session = tf.Session()
                self.session.run(tf.global_variables_initializer())
                saver = tf.train.Saver()
                saver.restore(sess=self.session, save_path=save_path)  # restore the saved model

            def predict(self, message):
                # Models trained under Python 2 or Python 3 both work in either environment
                content = unicode(message)
                data = [self.word_to_id[x] for x in content if x in self.word_to_id]

                feed_dict = {
                    self.model.input_x: kr.preprocessing.sequence.pad_sequences([data], self.config.seq_length),
                    self.model.keep_prob: 1.0
                }
                y_pred_cls = self.session.run(self.model.y_pred_cls, feed_dict=feed_dict)
                return self.categories[y_pred_cls[0]]

        if __name__ == '__main__':
            rnn_model = RnnModel()
            test_demo = ['三星ST550以全新的拍摄方式超越了以往任何一款数码相机',
                         '热火vs骑士前瞻:皇帝回乡二番战 东部次席唾手可得新浪体育讯北京时间3月30日7:00']
            for i in test_demo:
                print(rnn_model.predict(i))
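
    With a trained checkpoint under ../checkpoints/textrnn, the two demo sentences should come back as 科技 and 体育 respectively (a digital-camera review and an NBA preview).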

    4. Results

    Training takes close to 2 hours.

    The full code is available at: https://github.com/yifanhunter/textClassifier_chinese
