zoukankan      html  css  js  c++  java
  • 利用CNN进行多分类的文档分类

    # coding: utf-8
    
    import tensorflow as tf
    
    
    class TCNNConfig(object):
        """Hyper-parameters for the TextCNN model and its training loop.

        Every value is a plain class attribute. Callers instantiate the class
        and may override individual fields on the instance (for example
        ``vocab_size``) before building the model.
        """

        # --- input and embedding ---
        seq_length = 100         # sequence length (tokens per sample)
        vocab_size = 5000        # vocabulary size
        embedding_dim = 20       # dimensionality of each word vector

        # --- network ---
        num_filters = 256        # number of convolution filters
        kernel_size = 5          # convolution window width
        hidden_dim = 128         # units in the fully connected layer
        num_classes = 73         # number of output classes

        # --- optimisation ---
        dropout_keep_prob = 0.8  # dropout keep probability used while training
        learning_rate = 0.001    # optimizer learning rate
        batch_size = 128         # samples per training batch
        num_epochs = 5           # total passes over the training set

        # --- logging ---
        print_per_batch = 100    # print train/validation metrics every N batches
        save_per_batch = 10      # write TensorBoard summaries every N batches
    
    
    class TextCNN(object):
        """Text classification CNN: embedding -> 1-D convolution -> global max pooling -> dense layers.

        Constructing an instance adds the placeholders, variables, loss,
        optimizer and accuracy ops to the current default TensorFlow graph.
        """

        def __init__(self, config):
            """Create the input placeholders and build the graph.

            Args:
                config: hyper-parameter object (e.g. a ``TCNNConfig`` instance).
                    Fields read: seq_length, num_classes, vocab_size,
                    embedding_dim, num_filters, kernel_size, hidden_dim and
                    learning_rate.
            """
            self.config = config

            # The three values fed at run time.
            # Token ids, shape (batch, seq_length).
            self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
            # Label distributions, shape (batch, num_classes); presumably one-hot.
            self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
            # Dropout keep probability (1.0 disables dropout).
            self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')

            self.cnn()

        def cnn(self):
            """Build the network and set logits, y_pred_cls, loss, optim and acc on self."""
            # Word-embedding lookup, pinned to the CPU.
            with tf.device('/cpu:0'):
                embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
                embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)

            with tf.name_scope("cnn"):
                # 1-D convolution over the token axis.
                conv = tf.layers.conv1d(embedding_inputs, self.config.num_filters, self.config.kernel_size, name='conv')
                # Global max pooling: keep each filter's maximum over the sequence axis.
                # `axis` replaces the deprecated `reduction_indices` alias.
                gmp = tf.reduce_max(conv, axis=[1], name='gmp')

            with tf.name_scope("score"):
                # Fully connected layer followed by dropout and ReLU.
                fc = tf.layers.dense(gmp, self.config.hidden_dim, name='fc1')
                fc = tf.contrib.layers.dropout(fc, self.keep_prob)
                fc = tf.nn.relu(fc)

                # Classifier.
                self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
                # Predicted class id. Softmax is monotonic, so argmax over the
                # raw logits picks the same class without an extra op.
                self.y_pred_cls = tf.argmax(self.logits, 1)

            with tf.name_scope("optimize"):
                # Loss: mean softmax cross-entropy.
                cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
                self.loss = tf.reduce_mean(cross_entropy)
                # Optimizer.
                self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

            with tf.name_scope("accuracy"):
                # Accuracy: fraction of samples whose predicted id equals the label's argmax.
                correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
                self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    #!/usr/bin/python
    # -*- coding: utf-8 -*-
    
    from __future__ import print_function
    
    import os
    import sys
    import time
    from datetime import timedelta
    
    import numpy as np
    import tensorflow as tf
    from sklearn import metrics
    
    from cnn_model import TCNNConfig, TextCNN
    from data.cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab
    
    # Data files, all resolved relative to the data/ directory.
    base_dir = 'data/'
    train_dir, test_dir, val_dir, vocab_dir = (
        os.path.join(base_dir, file_name)
        for file_name in ('train.txt', 'test.txt', 'test.txt', 'bbb.txt')
    )
    # NOTE(review): val_dir is the same file as test_dir, so the checkpoint is
    # selected on test-set accuracy and the reported test score is optimistic.
    # Consider pointing this at a separate validation split.

    # Checkpoint directory and file prefix for the best-validation model.
    save_dir = 'checkpoints/textcnn'
    save_path = os.path.join(save_dir, 'best_validation')
    
    
    def get_time_dif(start_time):
        """Return the wall-clock time elapsed since ``start_time``.

        ``start_time`` is a ``time.time()`` timestamp. The result is a
        ``timedelta`` rounded to the nearest whole second.
        """
        elapsed_seconds = int(round(time.time() - start_time))
        return timedelta(seconds=elapsed_seconds)
    
    
    def feed_data(x_batch, y_batch, keep_prob):
        """Build a ``feed_dict`` that binds one batch to the global ``model``.

        Maps ``model.input_x``, ``model.input_y`` and ``model.keep_prob`` to
        ``x_batch``, ``y_batch`` and ``keep_prob`` respectively.
        """
        placeholders = (model.input_x, model.input_y, model.keep_prob)
        values = (x_batch, y_batch, keep_prob)
        return dict(zip(placeholders, values))
    
    
    def evaluate(sess, x_, y_, batch_size=128):
        """Compute the mean loss and accuracy of the global ``model`` over a dataset.

        Args:
            sess: TensorFlow session holding the model variables.
            x_: input samples, passed to ``batch_iter``.
            y_: labels aligned with ``x_``.
            batch_size: samples per forward pass (defaults to 128, the value
                that was previously hard-coded).

        Returns:
            ``(loss, acc)``: per-sample averages. Each batch is weighted by its
            length, so a short final batch does not skew the result.

        Raises:
            ValueError: if ``x_`` is empty, since the average is undefined.
        """
        data_len = len(x_)
        if data_len == 0:
            raise ValueError("evaluate() needs at least one sample")

        total_loss = 0.0
        total_acc = 0.0
        for x_batch, y_batch in batch_iter(x_, y_, batch_size):
            batch_len = len(x_batch)
            # keep_prob=1.0 disables dropout during evaluation.
            feed_dict = feed_data(x_batch, y_batch, 1.0)
            loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)
            total_loss += loss * batch_len
            total_acc += acc * batch_len

        return total_loss / data_len, total_acc / data_len
    
    
    def train():
        """Train the global ``model`` with periodic validation and early stopping.

        Uses the module globals ``model``, ``config``, ``word_to_id`` and
        ``cat_to_id``. Writes TensorBoard summaries to ``tensorboard/textcnn``
        every ``config.save_per_batch`` batches. Every
        ``config.print_per_batch`` batches it prints train/validation metrics
        and saves a checkpoint to ``save_path`` if validation accuracy improved.
        Training stops early once more than ``require_improvement`` batches
        pass without improvement.
        """
        print("Configuring TensorBoard and Saver...")
        # When retraining, delete the tensorboard directory first; otherwise the
        # new graph and summaries are written on top of the old ones.
        tensorboard_dir = 'tensorboard/textcnn'
        if not os.path.exists(tensorboard_dir):
            os.makedirs(tensorboard_dir)

        tf.summary.scalar("loss", model.loss)
        tf.summary.scalar("accuracy", model.acc)
        merged_summary = tf.summary.merge_all()
        writer = tf.summary.FileWriter(tensorboard_dir)

        # Saver for the best-so-far checkpoint.
        saver = tf.train.Saver()
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        print("Loading training and validation data...")
        start_time = time.time()
        x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)
        x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)
        time_dif = get_time_dif(start_time)
        print("Time usage:", time_dif)

        # Progress-line template. The two literals are joined by implicit
        # concatenation inside the parentheses; a bare "+" on a continuation
        # line without parentheses is an IndentationError.
        msg = ('Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},'
               ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}')

        session = tf.Session()
        try:
            session.run(tf.global_variables_initializer())
            writer.add_graph(session.graph)

            print('Training and evaluating...')
            start_time = time.time()
            total_batch = 0  # batches processed so far
            best_acc_val = 0.0  # best validation accuracy so far
            last_improved = 0  # batch index of the last improvement
            require_improvement = 1000  # stop after this many batches without improvement

            stop = False
            for epoch in range(config.num_epochs):
                print('Epoch:', epoch + 1)
                for x_batch, y_batch in batch_iter(x_train, y_train, config.batch_size):
                    feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)

                    if total_batch % config.save_per_batch == 0:
                        # Write the training scalars to TensorBoard.
                        s = session.run(merged_summary, feed_dict=feed_dict)
                        writer.add_summary(s, total_batch)

                    if total_batch % config.print_per_batch == 0:
                        # Measure with dropout disabled in a SEPARATE feed dict so
                        # the optimizer step below still trains with dropout.
                        eval_feed = feed_data(x_batch, y_batch, 1.0)
                        loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=eval_feed)
                        loss_val, acc_val = evaluate(session, x_val, y_val)

                        if acc_val > best_acc_val:
                            # Keep the best checkpoint.
                            best_acc_val = acc_val
                            last_improved = total_batch
                            saver.save(sess=session, save_path=save_path)
                            improved_str = '*'
                        else:
                            improved_str = ''

                        time_dif = get_time_dif(start_time)
                        print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))

                    session.run(model.optim, feed_dict=feed_dict)  # optimizer step
                    total_batch += 1

                    if total_batch - last_improved > require_improvement:
                        # Validation accuracy has not improved for too long.
                        print("No optimization for a long time, auto-stopping...")
                        stop = True
                        break
                if stop:
                    break
        finally:
            # Flush pending summaries and release session resources.
            writer.close()
            session.close()
    
    
    def test():
        """Restore the best checkpoint and report metrics on the test set.

        Uses the module globals ``model``, ``config``, ``word_to_id``,
        ``cat_to_id`` and ``categories``. Prints the overall loss and accuracy,
        a per-class precision/recall/F1 report, the confusion matrix and the
        elapsed time.
        """
        print("Loading test data...")
        start_time = time.time()
        x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)

        session = tf.Session()
        try:
            session.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess=session, save_path=save_path)  # load the saved model

            print('Testing...')
            loss_test, acc_test = evaluate(session, x_test, y_test)
            msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'
            print(msg.format(loss_test, acc_test))

            batch_size = 128
            data_len = len(x_test)
            num_batch = (data_len + batch_size - 1) // batch_size  # ceil(data_len / batch_size)

            y_test_cls = np.argmax(y_test, 1)
            y_pred_cls = np.zeros(shape=data_len, dtype=np.int32)  # predicted class ids
            for i in range(num_batch):
                start_id = i * batch_size
                end_id = min(start_id + batch_size, data_len)
                feed_dict = {
                    model.input_x: x_test[start_id:end_id],
                    model.keep_prob: 1.0
                }
                y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)
        finally:
            session.close()

        # Pass the full label range explicitly. Report rows and confusion-matrix
        # axes then line up with `categories` even when some classes never
        # appear in the test set; otherwise the label count can disagree with
        # target_names. Assumes class id i corresponds to categories[i], which
        # target_names=categories already relies on.
        labels = np.arange(len(categories))

        print("Precision, Recall and F1-Score...")
        print(metrics.classification_report(y_test_cls, y_pred_cls, labels=labels, target_names=categories))

        print("Confusion Matrix...")
        cm = metrics.confusion_matrix(y_test_cls, y_pred_cls, labels=labels)
        print(cm)

        time_dif = get_time_dif(start_time)
        print("Time usage:", time_dif)
    
    
    if __name__ == '__main__':
        # Usage: python run_cnn.py [train|test]
        # Without an argument the script runs test(), as it did before.
        mode = sys.argv[1] if len(sys.argv) > 1 else 'test'
        if mode not in ('train', 'test'):
            raise ValueError("usage: python run_cnn.py [train|test]")

        print('Configuring CNN model...')
        config = TCNNConfig()
        if not os.path.exists(vocab_dir):  # rebuild the vocabulary if it is missing
            build_vocab(train_dir, vocab_dir, config.vocab_size)
        categories, cat_to_id = read_category()
        words, word_to_id = read_vocab(vocab_dir)
        # The embedding size must match the vocabulary actually loaded.
        config.vocab_size = len(words)
        model = TextCNN(config)

        if mode == 'train':
            train()
        else:
            test()
    # coding: utf-8
    
    from __future__ import print_function
    
    import os
    import tensorflow as tf
    import tensorflow.contrib.keras as kr
    import time
    from run_cnn import  get_time_dif
    from cnn_model import TCNNConfig, TextCNN
    from data.cnews_loader import read_category, read_vocab
    
    
    # Data and checkpoint roots.
    base_dir, save_dir = 'data/', 'checkpoints/textcnn'
    # Vocabulary file, and the prefix of the best-validation checkpoint to restore.
    vocab_dir, save_path = os.path.join(base_dir, 'bbb.txt'), os.path.join(save_dir, 'best_validation')
    
    
    class CnnModel:
        """Restore the trained TextCNN checkpoint once and classify individual texts."""

        def __init__(self):
            """Load categories and vocabulary, build the graph and restore ``save_path``."""
            self.config = TCNNConfig()
            self.categories, self.cat_to_id = read_category()
            self.words, self.word_to_id = read_vocab(vocab_dir)
            # The embedding size must match the vocabulary the checkpoint was trained with.
            self.config.vocab_size = len(self.words)
            self.model = TextCNN(self.config)

            self.session = tf.Session()
            self.session.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess=self.session, save_path=save_path)  # load the saved model

        def predict(self, message):
            """Return the predicted category for one text.

            Args:
                message: text to classify. It is iterated character by
                    character; characters found in the vocabulary are mapped to
                    their ids and the rest are dropped.

            Returns:
                An element of ``self.categories``.
            """
            content = message
            # On Python 3 a bytes object iterates as ints and would match nothing
            # in the vocabulary, so decode it first (assumes UTF-8). On Python 2
            # bytes is str, so this branch never fires there.
            if isinstance(content, bytes) and not isinstance(content, str):
                content = content.decode('utf-8')
            data = [self.word_to_id[x] for x in content if x in self.word_to_id]

            feed_dict = {
                # Pad/truncate the id list to seq_length (pad_sequences defaults).
                self.model.input_x: kr.preprocessing.sequence.pad_sequences([data], self.config.seq_length),
                self.model.keep_prob: 1.0  # no dropout at inference
            }

            y_pred_cls = self.session.run(self.model.y_pred_cls, feed_dict=feed_dict)
            return self.categories[y_pred_cls[0]]

        def close(self):
            """Release the TensorFlow session; ``predict`` cannot be used afterwards."""
            self.session.close()
    
    
    if __name__ == '__main__':
        # Demo: classify a few sample headlines and report the total elapsed time.
        t_begin = time.time()
        cnn_model = CnnModel()
        test_demo = [
            ' 16-12-08 今年前11个月我国进出口总值21.83万亿元 ',
            '16-12-08 英国知识产权局局长一行访问黄埔海关(图',
            '16-12-08 厦门海关启动“互联网+自主报关”改革 ',
            '16-12-08 江门海关“宪法日” 普法到一线(图)',
            '16-12-08 27.5公斤“萌萌哒”果实种子闯关被截获(图)',
            '广州海关推动“主动披露” 体现执法“宽严相济”',
            '16-12-07 胡伟在湛江出席全国沿海沿边地区基层反走私综合治理现场会(图)',
            '16-12-07 锐意改革 高效服务 海关力助湛江书写蓝色经济梦想',
            '16-12-07 红其拉甫海关查获毒品海洛因4.8千克(图)',
        ]
        for headline in test_demo:
            print(cnn_model.predict(headline))
        print(get_time_dif(t_begin))
  • 相关阅读:
    Handler使用总结(转)
    LR连接oracle数据库-lr_db_connect
    selenium2(WebDriver)环境搭建
    使用selenium控制滚动条(非整屏body)
    selenium-打开IE浏览器遇到问题记录
    使用re-sign.jar对apk进行重签名
    Robotium-无源码测试
    genymotion不能联网
    SQL 常用脚本
    todolist
  • 原文地址:https://www.cnblogs.com/yiduobaozhiblog1/p/9948014.html
Copyright © 2011-2022 走看看