【Text Classification - Chinese】textCNN

    Contents

    1. Overview
    2. Dataset
    3. Code
    4. Results

    1. Overview

    Building on the earlier English text-classification example, this post looks at Chinese text classification: a 10-class problem (sports, technology, games, finance, real estate, home furnishing, and so on).

    2. Dataset

    The dataset consists of news articles split across four files in the /data/cnews directory: the training set, test set, and validation set, plus a vocabulary file (the final cnews.vocab.txt can be omitted, since it is regenerated automatically during training). Data format: each line starts with the category label, followed by the text content.

    Training data download: https://pan.baidu.com/s/1ZHh98RrjQpG5Tm-yq73vBQ (extraction code: 2r04)

    Format of the training set: one sample per line, with the category label and the content separated by a tab.

    Format of cnews.vocab.txt: one character per line, with <PAD> added as the first entry.
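
    For illustration, here is a minimal sketch of how a single training line is laid out and parsed. The sample sentence is invented, but the label-tab-content layout matches what read_file() in cnews_loader.py expects:

        # coding: utf-8
        # One hypothetical line of cnews.train.txt: category label, a tab, then the news text.
        line = u'体育\t某队今晚在主场迎来一场关键比赛,主教练表示球员状态良好'
        label, content = line.strip().split('\t')
        print(label)              # -> 体育
        print(list(content)[:5])  # the loader tokenizes at the character level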

    3. Code

    3.1 Data loading: cnews_loader.py

        # coding: utf-8
        import sys
        from collections import Counter
        import numpy as np
        import tensorflow.contrib.keras as kr

        if sys.version_info[0] > 2:
            is_py3 = True
        else:
            reload(sys)
            sys.setdefaultencoding("utf-8")
            is_py3 = False

        def native_word(word, encoding='utf-8'):
            """If a model trained under Python 3 is used under Python 2, call this to convert the character encoding."""
            if not is_py3:
                return word.encode(encoding)
            else:
                return word

        def native_content(content):
            if not is_py3:
                return content.decode('utf-8')
            else:
                return content

        def open_file(filename, mode='r'):
            """
            Common file helper that works under both Python 2 and Python 3.
            mode: 'r' or 'w' for read or write
            """
            if is_py3:
                return open(filename, mode, encoding='utf-8', errors='ignore')
            else:
                return open(filename, mode)

        def read_file(filename):
            """Read a data file into (contents, labels)."""
            contents, labels = [], []
            with open_file(filename) as f:
                for line in f:
                    try:
                        label, content = line.strip().split('\t')
                        if content:
                            contents.append(list(native_content(content)))
                            labels.append(native_content(label))
                    except:
                        pass
            return contents, labels

        def build_vocab(train_dir, vocab_dir, vocab_size=5000):
            """Build the vocabulary from the training set and save it to disk."""
            data_train, _ = read_file(train_dir)
            all_data = []
            for content in data_train:
                all_data.extend(content)
            counter = Counter(all_data)
            count_pairs = counter.most_common(vocab_size - 1)
            words, _ = list(zip(*count_pairs))
            # Add <PAD> so that all texts can be padded to the same length
            words = ['<PAD>'] + list(words)
            open_file(vocab_dir, mode='w').write('\n'.join(words) + '\n')

        def read_vocab(vocab_dir):
            """Read the vocabulary file."""
            # words = open_file(vocab_dir).read().strip().split('\n')
            with open_file(vocab_dir) as fp:
                # Under Python 2, every value is converted to unicode
                words = [native_content(_.strip()) for _ in fp.readlines()]
            word_to_id = dict(zip(words, range(len(words))))
            return words, word_to_id

        def read_category():
            """Return the fixed list of categories."""
            categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']
            categories = [native_content(x) for x in categories]
            cat_to_id = dict(zip(categories, range(len(categories))))
            return categories, cat_to_id

        def to_words(content, words):
            """Convert a sequence of ids back into text."""
            return ''.join(words[x] for x in content)

        def process_file(filename, word_to_id, cat_to_id, max_length=600):
            """Convert a data file into id representation."""
            contents, labels = read_file(filename)

            data_id, label_id = [], []
            for i in range(len(contents)):
                data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])
                label_id.append(cat_to_id[labels[i]])
            # Use keras pad_sequences to pad every text to a fixed length
            x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)
            y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id))  # convert labels to one-hot
            return x_pad, y_pad

        def batch_iter(x, y, batch_size=64):
            """Generate shuffled batches of data."""
            data_len = len(x)
            num_batch = int((data_len - 1) / batch_size) + 1
            indices = np.random.permutation(np.arange(data_len))
            x_shuffle = x[indices]
            y_shuffle = y[indices]

            for i in range(num_batch):
                start_id = i * batch_size
                end_id = min((i + 1) * batch_size, data_len)
                yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]
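
    A minimal usage sketch of these helpers, assuming the cnews files sit under ../data/cnews as in run_cnn.py below:

        # Usage sketch for cnews_loader (paths assumed, matching run_cnn.py below).
        import os
        from cnews_loader import build_vocab, read_vocab, read_category, process_file, batch_iter

        train_dir = '../data/cnews/cnews.train.txt'
        vocab_dir = '../data/cnews/cnews.vocab.txt'

        if not os.path.exists(vocab_dir):  # only rebuild the vocabulary if it is missing
            build_vocab(train_dir, vocab_dir, vocab_size=5000)
        categories, cat_to_id = read_category()
        words, word_to_id = read_vocab(vocab_dir)

        x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, max_length=600)
        print(x_train.shape, y_train.shape)  # (num_samples, 600) and (num_samples, 10)

        for x_batch, y_batch in batch_iter(x_train, y_train, batch_size=64):
            pass  # each batch holds up to 64 padded id sequences and their one-hot labels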

    3.2 Model definition: cnn_model.py

    Defines the training hyperparameters (TCNNConfig) and the TextCNN() model.

        # coding: utf-8
        import tensorflow as tf

        class TCNNConfig(object):
            """CNN configuration parameters"""
            embedding_dim = 64       # embedding dimension
            seq_length = 600         # sequence length
            num_classes = 10         # number of classes
            num_filters = 256        # number of convolution filters
            kernel_size = 5          # convolution kernel size
            vocab_size = 5000        # vocabulary size
            hidden_dim = 128         # hidden units in the fully connected layer
            dropout_keep_prob = 0.5  # dropout keep probability
            learning_rate = 1e-3     # learning rate
            batch_size = 64          # training batch size
            num_epochs = 10          # total number of epochs
            print_per_batch = 100    # report results every N batches
            save_per_batch = 10      # write to tensorboard every N batches

        class TextCNN(object):
            """Text classification, CNN model"""
            def __init__(self, config):
                self.config = config
                # Three input placeholders
                self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
                self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
                self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
                self.cnn()

            def cnn(self):
                """Build the CNN graph"""
                # Embedding lookup
                with tf.device('/cpu:0'):
                    embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
                    embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)

                with tf.name_scope("cnn"):
                    # CNN layer
                    conv = tf.layers.conv1d(embedding_inputs, self.config.num_filters, self.config.kernel_size, name='conv')
                    # global max pooling layer
                    gmp = tf.reduce_max(conv, reduction_indices=[1], name='gmp')

                with tf.name_scope("score"):
                    # Fully connected layer, followed by dropout and relu activation
                    fc = tf.layers.dense(gmp, self.config.hidden_dim, name='fc1')
                    fc = tf.contrib.layers.dropout(fc, self.keep_prob)
                    fc = tf.nn.relu(fc)
                    # Classifier
                    self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
                    self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)  # predicted class

                with tf.name_scope("optimize"):
                    # Loss function: cross entropy
                    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
                    self.loss = tf.reduce_mean(cross_entropy)
                    # Optimizer
                    self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

                with tf.name_scope("accuracy"):
                    # Accuracy
                    correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
                    self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
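
    To see how tensors flow through this model, here is a quick shape sanity check (a sketch, TF 1.x graph mode assumed, using the defaults from TCNNConfig above):

        # Quick sanity check of the tensor shapes in TextCNN (TF 1.x assumed).
        import tensorflow as tf
        from cnn_model import TCNNConfig, TextCNN

        tf.reset_default_graph()
        model = TextCNN(TCNNConfig())
        # With the defaults above the static shapes are:
        #   input_x (?, 600) -> embedding (?, 600, 64) -> conv1d (?, 596, 256)
        #   -> global max pool (?, 256) -> fc1 (?, 128) -> logits (?, 10)
        print(model.input_x.shape, model.logits.shape)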

    3.3 Training and testing: run_cnn.py

        #!/usr/bin/python
        # -*- coding: utf-8 -*-
        from __future__ import print_function
        import os
        import sys
        import time
        from datetime import timedelta
        import numpy as np
        import tensorflow as tf
        from sklearn import metrics
        from cnn_model import TCNNConfig, TextCNN
        from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab

        base_dir = '../data/cnews'
        train_dir = os.path.join(base_dir, 'cnews.train.txt')
        test_dir = os.path.join(base_dir, 'cnews.test.txt')
        val_dir = os.path.join(base_dir, 'cnews.val.txt')
        vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
        save_dir = 'checkpoints/textcnn'
        save_path = os.path.join(save_dir, 'best_validation')  # path where the best validation model is saved

        def get_time_dif(start_time):
            """Return the elapsed time since start_time."""
            end_time = time.time()
            time_dif = end_time - start_time
            return timedelta(seconds=int(round(time_dif)))

        def feed_data(x_batch, y_batch, keep_prob):
            feed_dict = {
                model.input_x: x_batch,
                model.input_y: y_batch,
                model.keep_prob: keep_prob
            }
            return feed_dict

        def evaluate(sess, x_, y_):
            """Evaluate loss and accuracy on a given dataset."""
            data_len = len(x_)
            batch_eval = batch_iter(x_, y_, 128)
            total_loss = 0.0
            total_acc = 0.0
            for x_batch, y_batch in batch_eval:
                batch_len = len(x_batch)
                feed_dict = feed_data(x_batch, y_batch, 1.0)
                loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)
                total_loss += loss * batch_len
                total_acc += acc * batch_len
            return total_loss / data_len, total_acc / data_len

        def train():
            print("Configuring TensorBoard and Saver...")
            # Configure TensorBoard; delete the tensorboard directory before retraining, otherwise the graphs overlap
            tensorboard_dir = '../tensorboard/textcnn'
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            tf.summary.scalar("loss", model.loss)
            tf.summary.scalar("accuracy", model.acc)
            merged_summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(tensorboard_dir)

            # Configure Saver
            saver = tf.train.Saver()
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            print("Loading training and validation data...")
            # Load the training and validation sets
            start_time = time.time()
            x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)
            x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)
            time_dif = get_time_dif(start_time)
            print("Time usage:", time_dif)

            # Create the session
            session = tf.Session()
            session.run(tf.global_variables_initializer())
            writer.add_graph(session.graph)

            print('Training and evaluating...')
            start_time = time.time()
            total_batch = 0              # total number of batches processed
            best_acc_val = 0.0           # best validation accuracy so far
            last_improved = 0            # batch at which the last improvement happened
            require_improvement = 1000   # stop early if there is no improvement for 1000 batches

            flag = False
            for epoch in range(config.num_epochs):
                print('Epoch:', epoch + 1)
                batch_train = batch_iter(x_train, y_train, config.batch_size)
                for x_batch, y_batch in batch_train:
                    feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)
                    # print("x_batch is {}".format(x_batch.shape))
                    if total_batch % config.save_per_batch == 0:
                        # Periodically write training results to tensorboard scalars
                        s = session.run(merged_summary, feed_dict=feed_dict)
                        writer.add_summary(s, total_batch)
                    if total_batch % config.print_per_batch == 0:
                        # Periodically report performance on the training and validation sets
                        feed_dict[model.keep_prob] = 1.0
                        loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict)
                        loss_val, acc_val = evaluate(session, x_val, y_val)
                        if acc_val > best_acc_val:
                            # Save the best result so far
                            best_acc_val = acc_val
                            last_improved = total_batch
                            saver.save(sess=session, save_path=save_path)
                            improved_str = '*'
                        else:
                            improved_str = ''
                        time_dif = get_time_dif(start_time)
                        msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \
                              + ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'
                        print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))

                    session.run(model.optim, feed_dict=feed_dict)  # run the optimizer
                    total_batch += 1

                    if total_batch - last_improved > require_improvement:
                        # Validation accuracy has not improved for a long time; stop training early
                        print("No optimization for a long time, auto-stopping...")
                        flag = True
                        break  # exit the inner loop
                if flag:  # as above, exit the outer loop
                    break

        def test():
            print("Loading test data...")
            start_time = time.time()
            x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)

            session = tf.Session()
            session.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess=session, save_path=save_path)  # restore the saved model

            print('Testing...')
            loss_test, acc_test = evaluate(session, x_test, y_test)
            msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
            print(msg.format(loss_test, acc_test))

            batch_size = 128
            data_len = len(x_test)
            num_batch = int((data_len - 1) / batch_size) + 1

            y_test_cls = np.argmax(y_test, 1)
            y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32)  # holds the predicted classes
            for i in range(num_batch):  # process batch by batch
                start_id = i * batch_size
                end_id = min((i + 1) * batch_size, data_len)
                feed_dict = {
                    model.input_x: x_test[start_id:end_id],
                    model.keep_prob: 1.0
                }
                y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)

            # Evaluation
            print("Precision, Recall and F1-Score...")
            print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))

            # Confusion matrix
            print("Confusion Matrix...")
            cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)
            print(cm)

            time_dif = get_time_dif(start_time)
            print("Time usage:", time_dif)

        if __name__ == '__main__':
            config = TCNNConfig()
            if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it does not exist (here it already exists, so this is skipped)
                build_vocab(train_dir, vocab_dir, config.vocab_size)
            categories, cat_to_id = read_category()
            words, word_to_id = read_vocab(vocab_dir)
            config.vocab_size = len(words)
            model = TextCNN(config)
            option = 'train'
            if option == 'train':
                train()
            else:
                test()
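
    As written, the script always trains because option is hard-coded to 'train'; after training, change it to 'test' to evaluate the saved checkpoint. Below is a sketch of an alternative __main__ block that picks the mode from the command line instead (this argument handling is an addition for illustration, not part of the original script); usage would be python run_cnn.py train followed by python run_cnn.py test.

        # Alternative __main__ for run_cnn.py (sketch): choose train/test on the command line.
        # Relies on the functions and paths defined earlier in run_cnn.py.
        if __name__ == '__main__':
            if len(sys.argv) != 2 or sys.argv[1] not in ['train', 'test']:
                raise ValueError("usage: python run_cnn.py [train / test]")
            config = TCNNConfig()
            if not os.path.exists(vocab_dir):
                build_vocab(train_dir, vocab_dir, config.vocab_size)
            categories, cat_to_id = read_category()
            words, word_to_id = read_vocab(vocab_dir)
            config.vocab_size = len(words)
            model = TextCNN(config)
            if sys.argv[1] == 'train':
                train()
            else:
                test()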

    3.4 Prediction: predict.py

        # coding: utf-8
        from __future__ import print_function
        import os
        import tensorflow as tf
        import tensorflow.contrib.keras as kr
        from cnn_model import TCNNConfig, TextCNN
        from cnews_loader import read_category, read_vocab

        try:
            bool(type(unicode))
        except NameError:
            unicode = str

        base_dir = '../data/cnews'
        vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
        save_dir = '../checkpoints/textcnn'
        save_path = os.path.join(save_dir, 'best_validation')  # path where the best validation model is saved

        class CnnModel:
            def __init__(self):
                self.config = TCNNConfig()
                self.categories, self.cat_to_id = read_category()
                self.words, self.word_to_id = read_vocab(vocab_dir)
                self.config.vocab_size = len(self.words)
                self.model = TextCNN(self.config)
                self.session = tf.Session()
                self.session.run(tf.global_variables_initializer())
                saver = tf.train.Saver()
                saver.restore(sess=self.session, save_path=save_path)  # restore the saved model

            def predict(self, message):
                # Models trained under either Python 2 or Python 3 can be run under either version
                content = unicode(message)
                data = [self.word_to_id[x] for x in content if x in self.word_to_id]

                feed_dict = {
                    self.model.input_x: kr.preprocessing.sequence.pad_sequences([data], self.config.seq_length),
                    self.model.keep_prob: 1.0
                }

                y_pred_cls = self.session.run(self.model.y_pred_cls, feed_dict=feed_dict)
                return self.categories[y_pred_cls[0]]

        if __name__ == '__main__':
            cnn_model = CnnModel()
            test_demo = ['三星ST550以全新的拍摄方式超越了以往任何一款数码相机',
                         '热火vs骑士前瞻:皇帝回乡二番战 东部次席唾手可得新浪体育讯北京时间3月30日7:00']
            for i in test_demo:
                print(cnn_model.predict(i))
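
    Note that predict() mirrors the preprocessing in process_file(): each character is mapped to its id and the sequence is padded (or truncated) to seq_length = 600 before being fed to the model. A small standalone illustration of the padding step (toy ids and maxlen=10 chosen for brevity):

        # Illustration of kr.preprocessing.sequence.pad_sequences as used in predict().
        import tensorflow.contrib.keras as kr

        padded = kr.preprocessing.sequence.pad_sequences([[5, 8, 13]], maxlen=10)
        print(padded)  # zero (<PAD>) ids are prepended so the sequence reaches length 10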

    4. Results

    Full code is available at: https://github.com/yifanhunter/textClassifier_chinese

    Original article: https://www.cnblogs.com/yifanrensheng/p/13583443.html