  • [Text Classification - Chinese] textRNN

    1. Overview

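    Building on the earlier English text classification, this post turns to Chinese text classification: a 10-class problem (sports, technology, games, finance, real estate, home furnishing, and so on).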

    2. Dataset

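    The dataset consists of news articles in four files under the /data/cnews directory, as shown in the figure below: the test set, the training set, the validation set, and the vocabulary. The vocabulary file cnews.vocab.txt is optional because training generates it automatically. Data format: each line starts with the category label, followed by a tab and then the article text.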

    Training data: https://pan.baidu.com/s/1ZHh98RrjQpG5Tm-yq73vBQ (extraction code: 2r04)

    Format of the training set: one sample per line, with the label and the news text separated by a tab.

    Format of vocab.txt: one character per line, with <PAD> added as the first entry (so it gets id 0).
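    To make the two formats concrete, here is a minimal pure-Python sketch. It parses made-up lines in the `label<TAB>content` layout and builds a character vocabulary with `<PAD>` in first position, so padding maps to id 0. The helper name `build_char_vocab` and the sample strings are hypothetical. The project's own version is `read_file` plus `build_vocab` in cnews_loader.py.

```python
from collections import Counter


def build_char_vocab(lines, vocab_size=5000):
    """Hypothetical helper: parse 'label<TAB>content' lines and build a
    character-level vocabulary with <PAD> at index 0."""
    counter = Counter()
    labels = []
    for line in lines:
        label, content = line.strip().split('\t', 1)
        labels.append(label)
        counter.update(content)  # character level: every character is a token
    # Reserve one slot for <PAD>, then keep the most frequent characters
    words = ['<PAD>'] + [w for w, _ in counter.most_common(vocab_size - 1)]
    return labels, words


# Toy samples in the "label<TAB>content" layout (made up for illustration)
sample = ['体育\t球赛精彩', '科技\t手机新品', '体育\t球队夺冠']
labels, words = build_char_vocab(sample)
print(labels)     # ['体育', '科技', '体育']
print(words[:3])  # ['<PAD>', '球', '赛']
```

    '球' appears twice, so it is the most frequent character and gets id 1. Ties keep first-seen order.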

    3. Code

    3.1 Data loading: cnews_loader.py

        # coding: utf-8
        import sys
        from collections import Counter

        import numpy as np
        import tensorflow.contrib.keras as kr  # TensorFlow 1.x only; tf.contrib does not exist in 2.x

        if sys.version_info[0] > 2:
            is_py3 = True
        else:
            reload(sys)
            sys.setdefaultencoding("utf-8")
            is_py3 = False


        def native_word(word, encoding='utf-8'):
            """If a model trained under Python 3 is used under Python 2, call this to convert the character encoding."""
            if not is_py3:
                return word.encode(encoding)
            else:
                return word


        def native_content(content):
            """Decode UTF-8 bytes to unicode under Python 2; return the content unchanged under Python 3."""
            if not is_py3:
                return content.decode('utf-8')
            else:
                return content


        def open_file(filename, mode='r'):
            """
            Common file-opening helper that works under both Python 2 and Python 3.
            mode: 'r' or 'w' for read or write
            """
            if is_py3:
                return open(filename, mode, encoding='utf-8', errors='ignore')
            else:
                return open(filename, mode)


        def read_file(filename):
            """Read a data file in which each line is 'label<TAB>content'."""
            contents, labels = [], []
            with open_file(filename) as f:
                for line in f:
                    try:
                        label, content = line.strip().split('\t')
                        if content:
                            contents.append(list(native_content(content)))  # character-level tokens
                            labels.append(native_content(label))
                    except ValueError:
                        pass  # skip malformed lines (wrong number of fields or undecodable text)
            return contents, labels


        def build_vocab(train_dir, vocab_dir, vocab_size=5000):
            """Build a character vocabulary from the training set and save it."""
            data_train, _ = read_file(train_dir)

            all_data = []
            for content in data_train:
                all_data.extend(content)
            counter = Counter(all_data)
            count_pairs = counter.most_common(vocab_size - 1)
            words, _ = list(zip(*count_pairs))
            # Add a <PAD> token (id 0) used to pad all texts to the same length
            words = ['<PAD>'] + list(words)
            with open_file(vocab_dir, mode='w') as f:
                f.write('\n'.join(words) + '\n')


        def read_vocab(vocab_dir):
            """Read the vocabulary."""
            # words = open_file(vocab_dir).read().strip().split('\n')
            with open_file(vocab_dir) as fp:
                # Under Python 2, convert every entry to unicode
                words = [native_content(_.strip()) for _ in fp.readlines()]
            word_to_id = dict(zip(words, range(len(words))))
            return words, word_to_id


        def read_category():
            """Return the fixed list of categories."""
            # sports, finance, real estate, home, education, technology, fashion, current affairs, games, entertainment
            categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']
            categories = [native_content(x) for x in categories]
            cat_to_id = dict(zip(categories, range(len(categories))))
            return categories, cat_to_id


        def to_words(content, words):
            """Convert an id sequence back to text."""
            return ''.join(words[x] for x in content)


        def process_file(filename, word_to_id, cat_to_id, max_length=600):
            """Convert a data file to padded id sequences and one-hot labels."""
            contents, labels = read_file(filename)

            data_id, label_id = [], []
            for i in range(len(contents)):
                # Characters missing from the vocabulary are dropped
                data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])
                label_id.append(cat_to_id[labels[i]])
            # Use Keras' pad_sequences to pad every text to a fixed length
            x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)
            y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id))  # convert labels to one-hot

            return x_pad, y_pad


        def batch_iter(x, y, batch_size=64):
            """Generate shuffled mini-batches."""
            data_len = len(x)
            num_batch = int((data_len - 1) / batch_size) + 1

            indices = np.random.permutation(np.arange(data_len))
            x_shuffle = x[indices]
            y_shuffle = y[indices]

            for i in range(num_batch):
                start_id = i * batch_size
                end_id = min((i + 1) * batch_size, data_len)
                yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]
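
    `process_file` delegates padding to Keras' `pad_sequences`, whose defaults are `padding='pre'` and `truncating='pre'` with fill value 0. Short texts get zeros in front, and texts longer than `max_length` lose their beginning. The NumPy sketch below mimics those defaults and `to_categorical`, so shapes can be checked without TensorFlow 1.x installed. `pad_pre` and `one_hot` are hypothetical names, not part of the project.

```python
import numpy as np


def pad_pre(sequences, maxlen, value=0):
    """Mimic keras pad_sequences defaults: pad and truncate at the front."""
    out = np.full((len(sequences), maxlen), value, dtype=np.int32)
    for i, seq in enumerate(sequences):
        trunc = seq[-maxlen:]  # 'pre' truncation keeps the last maxlen ids
        if len(trunc):
            out[i, -len(trunc):] = trunc  # 'pre' padding leaves zeros in front
    return out


def one_hot(label_ids, num_classes):
    """Mimic to_categorical: one row per label with a single 1.0."""
    return np.eye(num_classes, dtype=np.float32)[label_ids]


x = pad_pre([[5, 6], [1, 2, 3, 4, 5]], maxlen=4)
print(x)
# [[0 0 5 6]
#  [2 3 4 5]]
print(one_hot([2, 0], num_classes=3))
```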

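    In `batch_iter`, `int((data_len - 1) / batch_size) + 1` is a ceiling division, so the final batch may be smaller than `batch_size`. This small check reproduces the start/end indices the generator slices with. The function name `batch_bounds` is hypothetical.

```python
import math


def batch_bounds(data_len, batch_size):
    """Return the (start, end) index pairs that batch_iter slices with."""
    num_batch = int((data_len - 1) / batch_size) + 1
    return [(i * batch_size, min((i + 1) * batch_size, data_len))
            for i in range(num_batch)]


bounds = batch_bounds(10, 4)
print(bounds)  # [(0, 4), (4, 8), (8, 10)]

# For data_len >= 1 the formula equals ceil(data_len / batch_size)
assert all(int((n - 1) / 4) + 1 == math.ceil(n / 4) for n in range(1, 100))
```
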
    3.2 Model construction: rnn_model.py

        #!/usr/bin/python
        # -*- coding: utf-8 -*-
        import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.contrib, tf.layers)


        class TRNNConfig(object):
            """RNN configuration parameters"""
            # Model parameters
            embedding_dim = 64       # embedding dimension
            seq_length = 600         # sequence length (characters)
            num_classes = 10         # number of classes
            vocab_size = 5000        # vocabulary size
            num_layers = 2           # number of hidden (RNN) layers
            hidden_dim = 128         # hidden units per layer
            rnn = 'gru'              # 'lstm' or 'gru'
            # Training parameters
            dropout_keep_prob = 0.8  # dropout keep probability
            learning_rate = 1e-3     # learning rate
            batch_size = 128         # training batch size
            num_epochs = 10          # total number of epochs
            print_per_batch = 100    # print results every N batches
            save_per_batch = 10      # write to TensorBoard every N batches


        class TextRNN(object):
            """RNN model for text classification"""
            def __init__(self, config):
                self.config = config
                # Three input placeholders
                self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
                self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
                self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
                self.rnn()

            def rnn(self):
                """Build the RNN model"""
                def lstm_cell():  # LSTM cell
                    return tf.contrib.rnn.BasicLSTMCell(self.config.hidden_dim, state_is_tuple=True)

                def gru_cell():  # GRU cell
                    return tf.contrib.rnn.GRUCell(self.config.hidden_dim)

                def dropout():  # wrap each RNN cell with dropout on its output
                    if self.config.rnn == 'lstm':
                        cell = lstm_cell()
                    else:
                        cell = gru_cell()
                    return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob)

                # Embedding lookup
                with tf.device('/cpu:0'):
                    embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
                    embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)

                with tf.name_scope("rnn"):
                    # Multi-layer RNN
                    cells = [dropout() for _ in range(self.config.num_layers)]
                    rnn_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)

                    _outputs, _ = tf.nn.dynamic_rnn(cell=rnn_cell, inputs=embedding_inputs, dtype=tf.float32)
                    # Use the output at the last time step. pad_sequences pads at the front by
                    # default, so this step corresponds to the end of the real text
                    last = _outputs[:, -1, :]

                with tf.name_scope("score"):
                    # Fully connected layer, followed by dropout and ReLU activation
                    fc = tf.layers.dense(last, self.config.hidden_dim, name='fc1')
                    fc = tf.contrib.layers.dropout(fc, self.keep_prob)
                    fc = tf.nn.relu(fc)
                    # Classifier
                    self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
                    self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)  # predicted class

                with tf.name_scope("optimize"):
                    # Loss function: cross-entropy
                    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
                    self.loss = tf.reduce_mean(cross_entropy)
                    # Optimizer
                    self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

                with tf.name_scope("accuracy"):
                    # Accuracy
                    correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
                    self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
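
    Here is a NumPy sketch of what one `GRUCell` step computes, following the layout of TensorFlow 1.x's `GRUCell`:
    - the reset gate `r` and update gate `u` are computed from `[x, h]`;
    - a `tanh` candidate is computed from `[x, r * h]`;
    - the new state is `new_h = u * h + (1 - u) * candidate`.

    The sketch runs a single layer over a batch and returns the last time step, like `_outputs[:, -1, :]` in TextRNN. It is a simplification under stated assumptions: random weights with zero biases (not TF's initializers), no dropout, one layer instead of `num_layers`, and hypothetical function names.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_step(x, h, W_g, b_g, W_c, b_c):
    """One GRU step: gates from [x, h], candidate from [x, r * h]."""
    gates = sigmoid(np.concatenate([x, h], axis=1) @ W_g + b_g)
    r, u = np.split(gates, 2, axis=1)  # reset gate, update gate
    c = np.tanh(np.concatenate([x, r * h], axis=1) @ W_c + b_c)
    return u * h + (1.0 - u) * c  # u keeps part of the old state


def run_gru(inputs, hidden_dim, rng):
    """Run one GRU layer over [batch, time, dim] inputs; return the last step."""
    batch, steps, dim = inputs.shape
    W_g = rng.normal(scale=0.1, size=(dim + hidden_dim, 2 * hidden_dim))
    b_g = np.zeros(2 * hidden_dim)
    W_c = rng.normal(scale=0.1, size=(dim + hidden_dim, hidden_dim))
    b_c = np.zeros(hidden_dim)
    h = np.zeros((batch, hidden_dim))
    for t in range(steps):
        h = gru_step(inputs[:, t, :], h, W_g, b_g, W_c, b_c)
    return h  # for a GRU, the output at the last step equals the final state


rng = np.random.default_rng(0)
emb = rng.normal(size=(2, 7, 64))  # batch=2, 7 time steps, embedding_dim=64
last = run_gru(emb, hidden_dim=128, rng=rng)
print(last.shape)  # (2, 128)
```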

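    The `optimize` and `accuracy` scopes reduce to a few lines of math. The NumPy sketch below computes the mean softmax cross-entropy against one-hot labels, using the usual max-subtraction for numerical stability, and the argmax accuracy. Softmax is monotonic, so taking `argmax` of the raw logits gives the same class as `tf.argmax(tf.nn.softmax(self.logits), 1)`. The function names and toy numbers are illustrative only.

```python
import numpy as np


def softmax_cross_entropy(logits, labels):
    """Mean softmax cross-entropy with one-hot labels."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-(labels * log_probs).sum(axis=1).mean())


def accuracy(logits, labels):
    """Fraction of rows whose argmax prediction matches the one-hot label."""
    return float((logits.argmax(axis=1) == labels.argmax(axis=1)).mean())


logits = np.array([[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
labels = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
print(softmax_cross_entropy(logits, labels))
print(accuracy(logits, labels))  # 0.5
```
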
    3.3 Running the code: run_cnn.py

        # coding: utf-8
        from __future__ import print_function
        import os
        import sys
        import time
        from datetime import timedelta

        import numpy as np
        import tensorflow as tf
        from sklearn import metrics

        from rnn_model import TRNNConfig, TextRNN
        from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab

        base_dir = '../data/cnews'
        train_dir = os.path.join(base_dir, 'cnews.train.txt')
        test_dir = os.path.join(base_dir, 'cnews.test.txt')
        val_dir = os.path.join(base_dir, 'cnews.val.txt')
        vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
        save_dir = '../checkpoints/textrnn'
        save_path = os.path.join(save_dir, 'best_validation')  # path of the best validation checkpoint


        def get_time_dif(start_time):
            """Return the elapsed time since start_time."""
            end_time = time.time()
            time_dif = end_time - start_time
            return timedelta(seconds=int(round(time_dif)))


        def feed_data(x_batch, y_batch, keep_prob):
            feed_dict = {
                model.input_x: x_batch,
                model.input_y: y_batch,
                model.keep_prob: keep_prob
            }
            return feed_dict


        def evaluate(sess, x_, y_):
            """Compute the loss and accuracy on a given dataset."""
            data_len = len(x_)
            batch_eval = batch_iter(x_, y_, 128)
            total_loss = 0.0
            total_acc = 0.0
            for x_batch, y_batch in batch_eval:
                batch_len = len(x_batch)
                feed_dict = feed_data(x_batch, y_batch, 1.0)
                loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)
                # Weight each batch by its size so the smaller last batch is not over-counted
                total_loss += loss * batch_len
                total_acc += acc * batch_len
            return total_loss / data_len, total_acc / data_len


        def train():
            print("Configuring TensorBoard and Saver...")
            # Configure TensorBoard. When retraining, delete the tensorboard folder first,
            # otherwise the new curves are drawn over the old ones
            tensorboard_dir = '../tensorboard/textrnn'
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            tf.summary.scalar("loss", model.loss)
            tf.summary.scalar("accuracy", model.acc)
            merged_summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(tensorboard_dir)
            # Configure the Saver
            saver = tf.train.Saver()
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            print("Loading training and validation data...")
            # Load the training and validation sets
            start_time = time.time()
            x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)
            x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)
            time_dif = get_time_dif(start_time)
            print("Time usage:", time_dif)

            # Create the session
            session = tf.Session()
            session.run(tf.global_variables_initializer())
            writer.add_graph(session.graph)
            print('Training and evaluating...')
            start_time = time.time()
            total_batch = 0             # total number of batches so far
            best_acc_val = 0.0          # best validation accuracy
            last_improved = 0           # batch at which validation accuracy last improved
            require_improvement = 1000  # stop early if there is no improvement for more than 1000 batches

            flag = False
            for epoch in range(config.num_epochs):
                print('Epoch:', epoch + 1)
                batch_train = batch_iter(x_train, y_train, config.batch_size)
                for x_batch, y_batch in batch_train:
                    feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)

                    if total_batch % config.save_per_batch == 0:
                        # Every save_per_batch batches, write training scalars to TensorBoard
                        s = session.run(merged_summary, feed_dict=feed_dict)
                        writer.add_summary(s, total_batch)

                    if total_batch % config.print_per_batch == 0:
                        # Every print_per_batch batches, report performance on the training and validation sets.
                        # Use a separate feed dict without dropout so the training step below keeps dropout on
                        eval_dict = feed_data(x_batch, y_batch, 1.0)
                        loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=eval_dict)
                        loss_val, acc_val = evaluate(session, x_val, y_val)

                        if acc_val > best_acc_val:
                            # Save the best result
                            best_acc_val = acc_val
                            last_improved = total_batch
                            saver.save(sess=session, save_path=save_path)
                            improved_str = '*'
                        else:
                            improved_str = ''

                        time_dif = get_time_dif(start_time)
                        msg = ('Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},'
                               ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}')
                        print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))

                    session.run(model.optim, feed_dict=feed_dict)  # run the optimizer
                    total_batch += 1

                    if total_batch - last_improved > require_improvement:
                        # Validation accuracy has not improved for a long time: stop early
                        print("No optimization for a long time, auto-stopping...")
                        flag = True
                        break  # exit the inner loop
                if flag:  # exit the outer loop as well
                    break


        def test():
            print("Loading test data...")
            start_time = time.time()
            x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)
            session = tf.Session()
            session.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess=session, save_path=save_path)  # load the saved model

            print('Testing...')
            loss_test, acc_test = evaluate(session, x_test, y_test)
            msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
            print(msg.format(loss_test, acc_test))

            batch_size = 128
            data_len = len(x_test)
            num_batch = int((data_len - 1) / batch_size) + 1

            y_test_cls = np.argmax(y_test, 1)
            y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32)  # stores the predictions
            for i in range(num_batch):  # process batch by batch, in order (no shuffling)
                start_id = i * batch_size
                end_id = min((i + 1) * batch_size, data_len)
                feed_dict = {
                    model.input_x: x_test[start_id:end_id],
                    model.keep_prob: 1.0
                }
                y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)

            # Evaluation
            print("Precision, Recall and F1-Score...")
            print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))
            # Confusion matrix
            print("Confusion Matrix...")
            cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
            print(cm)
            time_dif = get_time_dif(start_time)
            print("Time usage:", time_dif)


        if __name__ == '__main__':
            print('Configuring RNN model...')
            config = TRNNConfig()
            if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it does not exist
                build_vocab(train_dir, vocab_dir, config.vocab_size)
            categories, cat_to_id = read_category()
            words, word_to_id = read_vocab(vocab_dir)
            config.vocab_size = len(words)
            model = TextRNN(config)
            # Defaults to training; pass 'test' as the first argument to evaluate the saved model
            option = sys.argv[1] if len(sys.argv) > 1 else 'train'
            if option == 'train':
                train()
            else:
                test()
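
    `evaluate` multiplies each batch's loss and accuracy by `batch_len` before dividing by `data_len`. This matters because the last batch is usually smaller, and a plain average of batch means would over-weight it. Here is a small check with hypothetical counts; `weighted_mean` is an illustrative name.

```python
def weighted_mean(batch_values, batch_sizes):
    """Combine per-batch means into a dataset-level mean the way evaluate()
    does: multiply each batch result by its size, then divide by the total."""
    total = sum(v * n for v, n in zip(batch_values, batch_sizes))
    return total / sum(batch_sizes)


# Hypothetical run: 300 samples evaluated in batches of 128, 128 and 44
sizes = [128, 128, 44]
correct = [112, 96, 22]  # correctly classified samples per batch
accs = [c / n for c, n in zip(correct, sizes)]  # 0.875, 0.75, 0.5

# Per-sample 0/1 correctness flags for the same run
samples = []
for c, n in zip(correct, sizes):
    samples += [1] * c + [0] * (n - c)

print(weighted_mean(accs, sizes))  # equals the per-sample mean, 230/300
print(sum(samples) / len(samples))
print(sum(accs) / len(accs))  # naive mean of batch accuracies, skewed by the small last batch
```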

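    The early-stopping rule in `train()` evaluates only every `print_per_batch` batches, but it checks the patience condition after every optimizer step. The sketch below replays that rule on a hypothetical list of validation accuracies, with no TensorFlow involved, to show when training would stop. `early_stop_batch` is a made-up name.

```python
def early_stop_batch(val_accs, print_per_batch=100, require_improvement=1000):
    """Replay train()'s early-stopping rule on hypothetical validation
    accuracies (one per evaluation). Return the batch count at which
    training stops, or None if val_accs runs out first."""
    best, last_improved, total_batch = 0.0, 0, 0
    while True:
        if total_batch % print_per_batch == 0:
            k = total_batch // print_per_batch
            if k >= len(val_accs):
                return None
            if val_accs[k] > best:
                best, last_improved = val_accs[k], total_batch
        total_batch += 1  # one optimizer step
        if total_batch - last_improved > require_improvement:
            return total_batch


print(early_stop_batch([0.5, 0.6, 0.7] + [0.65] * 20))  # 1201
```

    With improvements at batches 0, 100 and 200 and none afterwards, training stops at batch 1201. That is the first point where more than `require_improvement` = 1000 batches have passed since the last improvement.
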
    3.4 Prediction: predict.py

        # coding: utf-8
        from __future__ import print_function
        import os
        import tensorflow as tf
        import tensorflow.contrib.keras as kr  # TensorFlow 1.x only
        from rnn_model import TRNNConfig, TextRNN
        from cnews_loader import read_category, read_vocab

        try:
            bool(type(unicode))
        except NameError:
            unicode = str  # Python 3 has no separate unicode type

        base_dir = '../data/cnews'
        vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
        save_dir = '../checkpoints/textrnn'
        save_path = os.path.join(save_dir, 'best_validation')  # path of the best validation checkpoint


        class RnnModel:
            def __init__(self):
                self.config = TRNNConfig()
                self.categories, self.cat_to_id = read_category()
                self.words, self.word_to_id = read_vocab(vocab_dir)
                self.config.vocab_size = len(self.words)
                self.model = TextRNN(self.config)
                self.session = tf.Session()
                self.session.run(tf.global_variables_initializer())
                saver = tf.train.Saver()
                saver.restore(sess=self.session, save_path=save_path)  # load the saved model

            def predict(self, message):
                # Works whether the model was trained under Python 2 or 3, in either environment
                content = unicode(message)
                # Map characters to ids; characters not in the vocabulary are dropped
                data = [self.word_to_id[x] for x in content if x in self.word_to_id]

                feed_dict = {
                    self.model.input_x: kr.preprocessing.sequence.pad_sequences([data], self.config.seq_length),
                    self.model.keep_prob: 1.0
                }
                y_pred_cls = self.session.run(self.model.y_pred_cls, feed_dict=feed_dict)
                return self.categories[y_pred_cls[0]]


        if __name__ == '__main__':
            rnn_model = RnnModel()
            # Two sample snippets: a Samsung ST550 camera blurb and a Heat vs. Cavaliers game preview
            test_demo = ['三星ST550以全新的拍摄方式超越了以往任何一款数码相机',
                         '热火vs骑士前瞻:皇帝回乡二番战 东部次席唾手可得新浪体育讯北京时间3月30日7:00']
            for i in test_demo:
                print(rnn_model.predict(i))
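
    `RnnModel.predict` keeps only characters that appear in the vocabulary, then pads or truncates at the front to `seq_length` (the `pad_sequences` defaults). The pure-Python sketch below mirrors that preprocessing with a toy vocabulary. `encode_for_predict` and the ids are hypothetical; real ids come from cnews.vocab.txt.

```python
def encode_for_predict(message, word_to_id, seq_length):
    """Mirror predict()'s preprocessing: characters -> ids (unknown ones
    dropped), then front-truncate and front-pad with 0 to seq_length."""
    data = [word_to_id[ch] for ch in message if ch in word_to_id]
    if len(data) > seq_length:
        data = data[len(data) - seq_length:]  # keep the end of the text
    return [0] * (seq_length - len(data)) + data


# Toy vocabulary with <PAD> at id 0, as in cnews.vocab.txt
word_to_id = {'<PAD>': 0, '球': 1, '赛': 2, '手': 3, '机': 4}
print(encode_for_predict('球赛X手机', word_to_id, seq_length=6))  # [0, 0, 1, 2, 3, 4]
```

    Unknown characters are dropped rather than mapped to a dedicated unknown token. A message made mostly of out-of-vocabulary characters therefore reaches the model as little more than padding.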

    4. Results

    Training time: close to 2 hours.

       

    The full code is available at: https://github.com/yifanhunter/textClassifier_chinese

  • Original post: https://www.cnblogs.com/yifanrensheng/p/13583446.html