zoukankan      html  css  js  c++  java
  • 小学数学应用题自动解题baseline

    https://github.com/bojone/ape210k_baseline

    Python 2.7 environment; requirements.txt:

    absl-py==0.11.0
    astor==0.8.1
    astroid==1.5.3
    backports.functools-lru-cache==1.5
    backports.weakref==1.0.post1
    bert4keras==0.8.8
    certifi==2020.6.20
    configparser==4.0.2
    enum34==1.1.10
    funcsigs==1.0.2
    futures==3.3.0
    gast==0.4.0
    google-pasta==0.2.0
    grpcio==1.33.2
    h5py==2.10.0
    isort==4.3.21
    Keras==2.3.1
    Keras-Applications==1.0.8
    Keras-Preprocessing==1.1.2
    lazy-object-proxy==1.4.3
    Markdown==3.1.1
    mock==3.0.5
    mpmath==1.1.0
    numpy==1.16.0
    pandas==0.20.3
    protobuf==3.13.0
    pylint==1.7.2
    python-dateutil==2.8.1
    pytz==2020.1
    PyYAML==5.3.1
    scipy==1.2.3
    six==1.15.0
    sympy==1.1.1
    tensorboard==1.14.0
    tensorflow==1.14.0
    tensorflow-estimator==1.14.0
    termcolor==1.1.0
    Werkzeug==1.0.1
    wrapt==1.12.1
    tqdm==4.50.2
    #! -*- coding: utf-8 -*-
    # 用Seq2Seq做小学数学应用题
    # 数据集为ape210k:https://github.com/Chenny0808/ape210k
    # Base版准确率为70%+,Large版准确率为73%+
    # 实测环境:tensorflow 1.14 + keras 2.3.1 + bert4keras 0.8.8
    # 介绍链接:https://kexue.fm/archives/7809
    
    from __future__ import division
    import json, re
    import numpy as np
    import pandas as pd
    from tqdm import tqdm
    from bert4keras.backend import keras, K
    from bert4keras.layers import Loss
    from bert4keras.models import build_transformer_model
    from bert4keras.tokenizers import Tokenizer, load_vocab
    from bert4keras.optimizers import Adam
    from bert4keras.snippets import sequence_padding, open
    from bert4keras.snippets import DataGenerator, AutoRegressiveDecoder
    from keras.models import Model
    from sympy import Integer
    
    # Basic hyper-parameters
    maxlen = 192  # max token length of question + equation fed to the model
    batch_size = 32
    epochs = 100
    
    # BERT configuration: local paths to the pretrained mixed-corpus BERT base model
    config_path = '/Users/war/Downloads/uer/mixed_corpus_bert_base_model/bert_config.json'
    checkpoint_path = '/Users/war/Downloads/uer/mixed_corpus_bert_base_model/bert_model.ckpt'
    dict_path = '/Users/war/Downloads/uer/mixed_corpus_bert_base_model/vocab.txt'
    
    
    def is_equal(a, b):
        """Return True if the two results are numerically equal.

        Both values are coerced to float and rounded to 6 decimal places
        before comparison, so tiny float noise does not matter.
        """
        return round(float(a), 6) == round(float(b), 6)
    
    
    def remove_bucket(equation):
        """Strip redundant parentheses from an eval-able equation string.

        Each '(...)' pair is tentatively blanked out with spaces — this keeps
        the string length unchanged, so the bracket positions recorded up
        front stay valid — and the removal is kept only when the evaluated
        value is unchanged.  All spaces are dropped at the end.
        """
        stack, pairs = [], []
        for pos, ch in enumerate(equation):
            if ch == '(':
                stack.append(pos)
            elif ch == ')':
                pairs.append((stack.pop(), pos))
        target = eval(equation)
        for left, right in pairs:
            candidate = '%s %s %s' % (
                equation[:left], equation[left + 1:right], equation[right + 1:]
            )
            try:
                if is_equal(eval(candidate.replace(' ', '')), target):
                    equation = candidate
            except:
                pass
        return equation.replace(' ', '')
    
    
    def load_data(filename):
        """Read a jsonlines training file and normalise each sample so that
        `equation` can be eval'ed (see https://kexue.fm/archives/7809).

        Returns a list of (question, equation, answer) triples; samples whose
        equation does not evaluate to the stated answer are dropped.

        Note: the regex patterns below were corrupted in the published copy
        (backslashes stripped, '\\1' collapsed to an octal escape); they are
        restored here as raw strings.
        """
        D = []
        for l in open(filename):
            l = json.loads(l)
            question, equation, answer = l['original_text'], l['equation'], l['ans']
            # Mixed numbers: "3(1/2)" -> "(3+1/2)"
            question = re.sub(r'(\d+)\((\d+/\d+)\)', r'(\1+\2)', question)
            equation = re.sub(r'(\d+)\((\d+/\d+)\)', r'(\1+\2)', equation)
            answer = re.sub(r'(\d+)\((\d+/\d+)\)', r'(\1+\2)', answer)
            # Remaining "number(" cases: insert the implied '+'
            equation = re.sub(r'(\d+)\(', r'\1+(', equation)
            answer = re.sub(r'(\d+)\(', r'\1+(', answer)
            # Drop brackets around plain fractions in the question text
            question = re.sub(r'\((\d+/\d+)\)', r'\1', question)
            # Percentages: "75%" -> "(75/100)"
            equation = re.sub(r'([\.\d]+)%', r'(\1/100)', equation)
            answer = re.sub(r'([\.\d]+)%', r'(\1/100)', answer)
            # Ratios use ':' -> division; any leftover '%' -> '/100'
            equation = equation.replace(':', '/').replace('%', '/100')
            answer = answer.replace(':', '/').replace('%', '/100')
            if equation[:2] == 'x=':
                equation = equation[2:]
            try:
                # Keep only samples whose equation reproduces the answer
                if is_equal(eval(equation), eval(answer)):
                    D.append((question, remove_bucket(equation), answer))
            except:
                continue
        return D
    
    
    # Load the datasets (jsonlines files from the ape210k repository)
    train_data = load_data('/Users/war/Downloads/ape210k-master/data/train.ape.json')
    valid_data = load_data('/Users/war/Downloads/ape210k-master/data/valid.ape.json')
    test_data = load_data('/Users/war/Downloads/ape210k-master/data/test.ape.json')
    
    # Load the pre-simplified vocab (token dict + kept token ids) and build the tokenizer
    token_dict, keep_tokens = json.load(open('/Users/war/Downloads/苏--训练好的模型权重/token_dict_keep_tokens.json'))
    # Alternative: simplify the original BERT vocab on the fly
    # token_dict, keep_tokens = load_vocab(
    #     dict_path=dict_path,
    #     simplified=True,
    #     startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
    # )
    tokenizer = Tokenizer(token_dict, do_lower_case=True)
    
    
    class data_generator(DataGenerator):
        """Batch generator: encodes (question, equation) pairs and yields
        padded [token_ids, segment_ids] batches (labels live in the graph).
        """
        def __iter__(self, random=False):
            token_batch, segment_batch = [], []
            for is_end, (question, equation, answer) in self.sample(random):
                tokens, segments = tokenizer.encode(
                    question, equation, maxlen=maxlen
                )
                token_batch.append(tokens)
                segment_batch.append(segments)
                if is_end or len(token_batch) == self.batch_size:
                    yield [
                        sequence_padding(token_batch),
                        sequence_padding(segment_batch),
                    ], None
                    token_batch, segment_batch = [], []
    
    
    class CrossEntropy(Loss):
        """Cross-entropy loss computed in-graph; input positions are masked
        out so only the equation part contributes.
        """
        def compute_loss(self, inputs, mask=None):
            y_true, y_mask, y_pred = inputs
            # Shift by one: prediction at step t scores the token at t+1;
            # the segment ids conveniently mark exactly the target positions.
            y_true, y_mask, y_pred = y_true[:, 1:], y_mask[:, 1:], y_pred[:, :-1]
            ce = K.sparse_categorical_crossentropy(y_true, y_pred)
            return K.sum(ce * y_mask) / K.sum(y_mask)
    
    
    # Build BERT in UniLM mode so it can be used as a seq2seq model
    model = build_transformer_model(
        config_path,
        checkpoint_path,
        application='unilm',
        keep_tokens=keep_tokens,  # keep only these tokens, shrinking the original vocab
    )
    
    # The loss is computed inside the graph by the CrossEntropy layer,
    # so compile() needs no external loss argument
    output = CrossEntropy(2)(model.inputs + model.outputs)
    
    model = Model(model.inputs, output)
    model.compile(optimizer=Adam(2e-5))
    model.summary()
    
    
    class AutoSolve(AutoRegressiveDecoder):
        """Seq2seq decoder that generates an equation for a question."""
        @AutoRegressiveDecoder.wraps(default_rtype='probas')
        def predict(self, inputs, output_ids, states):
            token_ids, segment_ids = inputs
            # Append already-decoded tokens (segment id 1) to the question
            full_tokens = np.concatenate([token_ids, output_ids], 1)
            full_segments = np.concatenate(
                [segment_ids, np.ones_like(output_ids)], 1
            )
            # Distribution over the next token only
            return model.predict([full_tokens, full_segments])[:, -1]

        def generate(self, text, topk=1):
            token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
            # Beam search over the decoder
            best_ids = self.beam_search([token_ids, segment_ids], topk)
            return tokenizer.decode(best_ids).replace(' ', '')
    
    
    # Decoder instance: no explicit start token, stop at the tokenizer's end
    # token, generated equations capped at 64 tokens
    autosolve = AutoSolve(start_id=None, end_id=tokenizer._token_end_id, maxlen=64)
    
    
    class Evaluator(keras.callbacks.Callback):
        """Per-epoch evaluation on valid_data; checkpoints the best weights."""
        def __init__(self):
            self.best_acc = 0.  # best validation accuracy seen so far

        def on_epoch_end(self, epoch, logs=None):
            metrics = self.evaluate(valid_data)  # run the evaluation
            if metrics['acc'] >= self.best_acc:
                self.best_acc = metrics['acc']
                model.save_weights('./best_model.weights')  # keep the best model
            metrics['best_acc'] = self.best_acc
            print('valid_data:', metrics)

        def evaluate(self, data, topk=1):
            total = right = 0.0
            for question, equation, answer in tqdm(data):
                total += 1
                predicted = autosolve.generate(question, topk)
                # A prediction counts as right if it evaluates to the answer;
                # unevaluable predictions simply count as wrong
                try:
                    if is_equal(eval(predicted), eval(answer)):
                        right += 1
                except:
                    pass
            return {'acc': right / total}
    
    
    def predict(in_file, out_file, topk=1):
        """Write predictions for a csv test set to *out_file*.

        Written for https://www.datafountain.cn/competitions/467 : reads the
        competition test csv (rows of id,question), predicts an equation for
        each question, evaluates it, and formats the answer according to the
        question type; out_file can be submitted directly (38%+ online
        accuracy).

        Note: the regexes and the '\\n' below were corrupted in the published
        copy (backslashes stripped, the newline escape split over two lines);
        restored here as raw strings / a proper escape.
        """
        fw = open(out_file, 'w', encoding='utf-8')
        raw_data = pd.read_csv(in_file, header=None, encoding='utf-8')
        for i, question in tqdm(raw_data.values):
            # Mixed numbers in the question: "3_1/2" -> "(3+1/2)"
            question = re.sub(r'(\d+)_(\d+/\d+)', r'(\1+\2)', question)
            pred_equation = autosolve.generate(question, topk)
            if '.' not in pred_equation:
                # Integer-only equation: wrap numbers in sympy Integer so
                # divisions stay exact fractions instead of truncating
                pred_equation = re.sub(r'([\d]+)', r'Integer(\1)', pred_equation)
            try:
                pred_answer = eval(pred_equation)
            except:
                # Unevaluable prediction: fall back to a random guess in 1..21
                pred_answer = np.random.choice(21) + 1
            if '.' in pred_equation:
                # Float result
                if u'百分之几' in question:  # question asks for a percentage
                    pred_answer = pred_answer * 100
                pred_answer = round(pred_answer, 2)
                if int(pred_answer) == pred_answer:
                    pred_answer = int(pred_answer)
                # Counting questions need an integer answer (inside [...] the
                # '|' characters are literal; the class just lists measure words)
                if (
                    re.findall(u'多少[辆|人|个|只|箱|包|本|束|头|盒|张]', question) or
                    re.findall(u'几[辆|人|个|只|箱|包|本|束|头|盒|张]', question)
                ):
                    if re.findall(u'至少|最少', question):
                        pred_answer = np.ceil(pred_answer)   # "at least": round up
                    elif re.findall(u'至多|最多', question):
                        pred_answer = np.floor(pred_answer)  # "at most": round down
                    else:
                        pred_answer = np.ceil(pred_answer)
                    pred_answer = int(pred_answer)
                pred_answer = str(pred_answer)
                if u'百分之几' in question:
                    pred_answer = pred_answer + '%'
            else:
                # Exact (sympy) result
                pred_answer = str(pred_answer)
                if '/' in pred_answer:
                    if re.findall(r'\d+/\d+', question):
                        # Question contains fractions: give an improper result
                        # as a mixed number "whole_num/den"
                        a, b = pred_answer.split('/')
                        a, b = int(a), int(b)
                        if a > b:
                            pred_answer = '%s_%s/%s' % (a // b, a % b, b)
                    else:
                        # Otherwise a fractional answer is unlikely wanted:
                        # round to an integer, direction chosen by the wording
                        if re.findall(u'至少|最少', question):
                            pred_answer = np.ceil(eval(pred_answer))
                        elif re.findall(u'至多|最多', question):
                            pred_answer = np.floor(eval(pred_answer))
                        else:
                            pred_answer = np.ceil(eval(pred_answer))
                        pred_answer = str(int(pred_answer))
            fw.write(str(i) + ',' + pred_answer + '\n')
            fw.flush()
        fw.close()
    
    
    if __name__ == '__main__':
    
        # Train from scratch, evaluating (and checkpointing the best weights)
        # after every epoch via the Evaluator callback
        evaluator = Evaluator()
        train_generator = data_generator(train_data, batch_size)
    
        model.fit_generator(
            train_generator.forfit(),
            steps_per_epoch=len(train_generator),
            epochs=epochs,
            callbacks=[evaluator]
        )
        # predict() expects file *paths* (csv in, csv out) — it calls
        # pd.read_csv(in_file) and open(out_file, 'w') — so pass the paths
        # directly instead of the (wrong) load_data() results
        predict('/Users/war/Downloads/test.csv',
                '/Users/war/Downloads/submit_example.csv')
    else:
        # When imported: load the trained weights and produce the submission
        model.load_weights('/Users/war/Downloads/苏--训练好的模型权重/best_model.weights')
        predict('/Users/war/Downloads/test.csv',
                '/Users/war/Downloads/submit_example.csv')
  • 相关阅读:
    实用机器学习 跟李沐学AI
    Explicitly drop temp table or let SQL Server handle it
    dotnettransformxdt and FatAntelope
    QQ拼音输入法 禁用模糊音
    (技术八卦)Java VS RoR
    Ruby on rails开发从头来(windows)(七)创建在线购物页面
    Ruby on rails开发从头来(windows)(十三)订单(Order)
    Ruby on rails开发从头来(windows)(十一)订单(Order)
    新员工自缢身亡,华为又站到了风口浪尖
    死亡汽油弹(Napalm Death)乐队的视频和来中国演出的消息
  • 原文地址:https://www.cnblogs.com/war1111/p/13962356.html
Copyright © 2011-2022 走看看