zoukankan      html  css  js  c++  java
  • tensorflow版本的tansformer训练IWSLT数据集

    代码来源:https://github.com/Kyubyong/transformer

    1、git clone https://github.com/Kyubyong/transformer.git

    2、pip install sentencepiece

    3、下载数据集

    进入到tansformer目录下,输入:sh download.sh

    运行成功之后,会有这么一些文件:

    de-en.de.xml中内容大致是这个样子的:

    <?xml version="1.0" encoding="UTF-8"?>
    <mteval>
    <srcset setid="iwslt2016-dev2010" srclang="german">
    <doc docid="69" genre="lectures">
    <url>http://www.ted.com/talks/lang/de/wade_davis_on_endangered_cultures.html</url>
    <description>Mit atemberaubenden Fotos und Geschichten feiert der National Geographic- Forschungsreisende Wade Davis, die außergewöhnliche Vielfalt der Ureinwohner der Welt, welche in alarmierender Anzahl von unserem Planeten verschwinden.</description>
    <keywords>anthropology,culture,environment,film,global issues,language,photography</keywords>
    <talkid>69</talkid>
    <title>Wade Davis über gefährdete Kulturen</title>
    <seg id="1"> Wissen Sie, eines der großen Vernügen beim Reisen und eine der Freuden bei der ethnographischen Forschung ist, gemeinsam mit den Menschen zu leben, die sich noch an die alten Tage erinnern können. Die ihre Vergangenheit noch immer im Wind spüren, sie auf vom Regen geglätteten Steinen berühren, sie in den bitteren Blättern der Pflanzen schmecken. </seg>
    <seg id="2"> Einfach das Wissen, dass Jaguar-Schamanen noch immer jenseits der Milchstraße reisen oder die Bedeutung der Mythen der Ältesten der Inuit noch voller Bedeutung sind, oder dass im Himalaya die Buddhisten noch immer den Atem des Dharma verfolgen, bedeutet, sich die zentrale Offenbarung der Anthropologie ins Gedächtnis zu rufen, das ist der Gedanke, dass die Welt, in der wir leben, nicht in einem absoluten Sinn existiert, sondern nur als ein Modell der Realität, als eine Folge einer Gruppe von bestimmten Möglichkeiten der Anpassung die unsere Ahnen, wenngleich erfolgreich, vor vielen Generationen wählten. </seg>
    <seg id="3"> Und natürlich teilen wir alle dieselben Anpassungsnotwendigkeiten. </seg>
    <seg id="4"> Wir werden alle geboren. Wir bringen Kinder zur Welt. </seg>
    <seg id="5"> Wir durchlaufen Initiationsrituale. </seg>
    <seg id="6"> Wir müssen uns mit der unaufhaltsamen Trennung durch den Tod auseinandersetzen und somit sollte es uns nicht überraschen, dass wir alle singen, tanzen und und Kunst hervorbringen. </seg>
    <seg id="7"> Aber interessant ist der einzigartige Tonfall des Liedes, der Rhythmus des Tanzes in jeder Kultur. </seg>
    <seg id="8"> Dabei spielt es keine Rolle, ob es sich um die Penan in den Wäldern von Borneo handelt, oder die Voodoo-Akolythen in Haiti, oder die Krieger in der Kaisut-Wüste von Nordkenia, die Curanderos in den Anden, oder eine Karawanserei mitten in der Sahara. Dies ist zufällig der Kollege, mit dem ich vor einem Monat in die Wüste gereist bin. Oder selbst ein Yak-Hirte an den Hängen des Qomolangma, Everest, der Gottmutter der Welt. </seg>
    <seg id="9"> All diese Menschen lehren uns, dass es noch andere Existenzmöglichkeiten, andere Denkweisen, andere Wege zur Orientierung auf der Erde gibt. </seg>
    <seg id="10"> Und das ist eine Vorstellung, die, wenn man darüber nachdenkt, einen nur mit Hoffnung erfüllen kann. </seg>
    <seg id="11"> Zusammen bilden die unzähligen Kulturen der Welt ein Netz aus spirituellem und kulturellem Leben, das die Erde umhüllt und für das Wohl der Erde genauso wichtig ist, wie das biologische Lebensnetz, das man als Biosphäre kennt. </seg>
    <seg id="12"> Man kann sich dieses kulturelle Lebensnetz als eine Ethnosphäre vorstellen. Ethnosphäre kann dabei als die Gesamtsumme aller Gedanken und Träume, Mythen Ideen, Inspirationen und Intuitionen, die von der menschlichen Vorstellungskraft seit den Anfängen des Bewusstseins hervorgebracht wurden, definiert werden. </seg>
    <seg id="13"> Die Ethnosphäre ist das großartige Vermächtnis der Menschheit. </seg>
    <seg id="14"> Sie ist das Symbol all dessen, was wir sind und wozu wir als erstaunlich wissbegierige Spezies fähig sind. </seg>
    <seg id="15"> Und genauso wie die Biosphäre stark abgetragen wurde, geschah dies mit der Ethnosphäre -- nur mit noch größerer Geschwindigkeit. </seg>
    <seg id="16"> Kein Biologe würde zum Beispiel wagen zu behaupten, dass 50% oder mehr aller Arten kurz vor dem Aussterben sind, da es einfach nicht stimmt. Und doch, dieses -- das apokalyptischste Szenarium auf dem Gebiet der biologischen Vielfalt -- entspricht kaum dem, was uns als optimistischstes Szenarium auf dem Gebiet der kulturellen Vielfalt bekannt ist. </seg>
    <seg id="17"> Und der entscheidende Indikator dafür ist das Aussterben der Sprachen. </seg>

    4、创建训练集、验证集、测试集

    python prepro.py --vocab_size 8000

    部分运行结果:

    trainer_interface.cc(615) LOG(INFO) Saving model: iwslt2016/segmented/bpe.model
    trainer_interface.cc(626) LOG(INFO) Saving vocabs: iwslt2016/segmented/bpe.vocab
    INFO:root:# Load trained bpe model
    INFO:root:# Segment
    INFO:root:Let's see how segmented data look like
    train1: ▁David ▁G all o : ▁Das ▁ist ▁Bill ▁L ange . ▁Ich ▁bin ▁Da ve ▁G all o .
    
    train2: ▁David ▁G all o : ▁This ▁is ▁Bill ▁L ange . ▁I ' m ▁Da ve ▁G all o .
    
    eval1: ▁Als ▁ich ▁11 ▁Jahre ▁alt ▁war , ▁wurde ▁ich ▁eines ▁Morgen s ▁von ▁den ▁Kl ängen ▁h eller ▁Freude ▁ge we ckt .
    
    eval2: ▁When ▁I ▁was ▁11 , ▁I ▁remember ▁w aking ▁up ▁one ▁morning ▁to ▁the ▁sound ▁of ▁j oy ▁in ▁my ▁house .
    
    test1: ▁Als ▁ich ▁in ▁meinen ▁20 ern ▁war , ▁hatte ▁ich ▁meine ▁erste ▁Psych other ap ie - P at ient in .
    
    INFO:root:Done

    运行之后会有:

    prepro.py中的内容如下:

    # -*- coding: utf-8 -*-
    #/usr/bin/python3
    '''
    Feb. 2019 by kyubyong park.
    kbpark.linguist@gmail.com.
    https://www.github.com/kyubyong/transformer.
    
    Preprocess the iwslt 2016 datasets.
    '''
    
    import os
    import errno
    import sentencepiece as spm
    import re
    from hparams import Hparams
    import logging
    
    logging.basicConfig(level=logging.INFO)
    
    def prepro(hp):
        """Load raw data -> Preprocessing -> Segmenting with sentencepice
        hp: hyperparams. argparse.
        """
        logging.info("# Check if raw files exist")
        train1 = "iwslt2016/de-en/train.tags.de-en.de"
        train2 = "iwslt2016/de-en/train.tags.de-en.en"
        eval1 = "iwslt2016/de-en/IWSLT16.TED.tst2013.de-en.de.xml"
        eval2 = "iwslt2016/de-en/IWSLT16.TED.tst2013.de-en.en.xml"
        test1 = "iwslt2016/de-en/IWSLT16.TED.tst2014.de-en.de.xml"
        test2 = "iwslt2016/de-en/IWSLT16.TED.tst2014.de-en.en.xml"
        for f in (train1, train2, eval1, eval2, test1, test2):
            if not os.path.isfile(f):
                raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), f)
    
        logging.info("# Preprocessing")
        # train
        _prepro = lambda x:  [line.strip() for line in open(x, 'r').read().split("
    ") 
                          if not line.startswith("<")]
        prepro_train1, prepro_train2 = _prepro(train1), _prepro(train2)
        assert len(prepro_train1)==len(prepro_train2), "Check if train source and target files match."
    
        # eval
        _prepro = lambda x: [re.sub("<[^>]+>", "", line).strip() 
                         for line in open(x, 'r').read().split("
    ") 
                         if line.startswith("<seg id")]
        prepro_eval1, prepro_eval2 = _prepro(eval1), _prepro(eval2)
        assert len(prepro_eval1) == len(prepro_eval2), "Check if eval source and target files match."
    
        # test
        prepro_test1, prepro_test2 = _prepro(test1), _prepro(test2)
        assert len(prepro_test1) == len(prepro_test2), "Check if test source and target files match."
    
        logging.info("Let's see how preprocessed data look like")
        logging.info("prepro_train1:", prepro_train1[0])
        logging.info("prepro_train2:", prepro_train2[0])
        logging.info("prepro_eval1:", prepro_eval1[0])
        logging.info("prepro_eval2:", prepro_eval2[0])
        logging.info("prepro_test1:", prepro_test1[0])
        logging.info("prepro_test2:", prepro_test2[0])
    
        logging.info("# write preprocessed files to disk")
        os.makedirs("iwslt2016/prepro", exist_ok=True)
        def _write(sents, fname):
            with open(fname, 'w') as fout:
                fout.write("
    ".join(sents))
    
        _write(prepro_train1, "iwslt2016/prepro/train.de")
        _write(prepro_train2, "iwslt2016/prepro/train.en")
        _write(prepro_train1+prepro_train2, "iwslt2016/prepro/train")
        _write(prepro_eval1, "iwslt2016/prepro/eval.de")
        _write(prepro_eval2, "iwslt2016/prepro/eval.en")
        _write(prepro_test1, "iwslt2016/prepro/test.de")
        _write(prepro_test2, "iwslt2016/prepro/test.en")
    
        logging.info("# Train a joint BPE model with sentencepiece")
        os.makedirs("iwslt2016/segmented", exist_ok=True)
        train = '--input=iwslt2016/prepro/train --pad_id=0 --unk_id=1 
                 --bos_id=2 --eos_id=3
                 --model_prefix=iwslt2016/segmented/bpe --vocab_size={} 
                 --model_type=bpe'.format(hp.vocab_size)
        spm.SentencePieceTrainer.Train(train)
    
        logging.info("# Load trained bpe model")
        sp = spm.SentencePieceProcessor()
        sp.Load("iwslt2016/segmented/bpe.model")
    
        logging.info("# Segment")
        def _segment_and_write(sents, fname):
            with open(fname, "w") as fout:
                for sent in sents:
                    pieces = sp.EncodeAsPieces(sent)
                    fout.write(" ".join(pieces) + "
    ")
    
        _segment_and_write(prepro_train1, "iwslt2016/segmented/train.de.bpe")
        _segment_and_write(prepro_train2, "iwslt2016/segmented/train.en.bpe")
        _segment_and_write(prepro_eval1, "iwslt2016/segmented/eval.de.bpe")
        _segment_and_write(prepro_eval2, "iwslt2016/segmented/eval.en.bpe")
        _segment_and_write(prepro_test1, "iwslt2016/segmented/test.de.bpe")
    
        logging.info("Let's see how segmented data look like")
        print("train1:", open("iwslt2016/segmented/train.de.bpe",'r').readline())
        print("train2:", open("iwslt2016/segmented/train.en.bpe", 'r').readline())
        print("eval1:", open("iwslt2016/segmented/eval.de.bpe", 'r').readline())
        print("eval2:", open("iwslt2016/segmented/eval.en.bpe", 'r').readline())
        print("test1:", open("iwslt2016/segmented/test.de.bpe", 'r').readline())
    
    if __name__ == '__main__':
        hparams = Hparams()
        parser = hparams.parser
        hp = parser.parse_args()
        prepro(hp)
        logging.info("Done")

    train中部分内容如下:

    David Gallo: Das ist Bill Lange. Ich bin Dave Gallo.
    Wir werden Ihnen einige Geschichten über das Meer in Videoform erzählen.
    Wir haben ein paar der unglaublichsten Aufnahmen der Titanic, die man je gesehen hat,, und wir werden Ihnen nichts davon zeigen.
    Die Wahrheit ist, dass die Titanic – obwohl sie alle Kinokassenrekorde bricht – nicht gerade die aufregendste Geschichte vom Meer ist.
    Ich denke, das Problem ist, dass wir das Meer für zu selbstverständlich halten.
    Wenn man darüber nachdenkt, machen die Ozeane 75 % des Planeten aus.
    Der Großteil der Erde ist Meerwasser.

    train.en.bpe中部分内容如下:

    ▁David ▁G all o : ▁This ▁is ▁Bill ▁L ange . ▁I ' m ▁Da ve ▁G all o .
    ▁And ▁we ' re ▁going ▁to ▁tell ▁you ▁some ▁stories ▁from ▁the ▁sea ▁here ▁in ▁video .
    ▁We ' ve ▁got ▁some ▁of ▁the ▁most ▁incredible ▁video ▁of ▁Tit an ic ▁that ' s ▁ever ▁been ▁seen , ▁and ▁we ' re ▁not ▁going ▁to ▁show ▁you ▁any ▁of ▁it .
    ▁The ▁truth ▁of ▁the ▁matter ▁is ▁that ▁the ▁Tit an ic ▁-- ▁even ▁though ▁it ' s ▁break ing ▁all ▁sorts ▁of ▁box ▁office ▁record s ▁-- ▁it ' s ▁not ▁the ▁most ▁exciting ▁story ▁from ▁the ▁sea .
    ▁And ▁the ▁problem , ▁I ▁think , ▁is ▁that ▁we ▁take ▁the ▁ocean ▁for ▁gr anted .
    ▁When ▁you ▁think ▁about ▁it , ▁the ▁oce ans ▁are ▁75 ▁percent ▁of ▁the ▁planet .
    ▁Most ▁of ▁the ▁planet ▁is ▁ocean ▁water .

    bpe.vocab部分内容如下:

    <pad>    0
    <unk>    0
    <s>    0
    </s>    0
    en    -0
    er    -1
    in    -2
    ▁t    -3
    ch    -4
    ▁a    -5
    ▁d    -6
    ▁w    -7
    ▁s    -8
    ▁th    -9
    nd    -10
    ie    -11
    es    -12

    5、train.py

    # -*- coding: utf-8 -*-
    #/usr/bin/python3
    '''
    Feb. 2019 by kyubyong park.
    kbpark.linguist@gmail.com.
    https://www.github.com/kyubyong/transformer
    '''
    import tensorflow as tf
    
    from model import Transformer
    from tqdm import tqdm
    from data_load import get_batch
    from utils import save_hparams, save_variable_specs, get_hypotheses, calc_bleu
    import os
    from hparams import Hparams
    import math
    import logging
    
    logging.basicConfig(level=logging.INFO)
    
    
    logging.info("# hparams")
    hparams = Hparams()
    parser = hparams.parser
    hp = parser.parse_args()
    save_hparams(hp, hp.logdir)
    
    logging.info("# Prepare train/eval batches")
    train_batches, num_train_batches, num_train_samples = get_batch(hp.train1, hp.train2,
                                                 hp.maxlen1, hp.maxlen2,
                                                 hp.vocab, hp.batch_size,
                                                 shuffle=True)
    eval_batches, num_eval_batches, num_eval_samples = get_batch(hp.eval1, hp.eval2,
                                                 100000, 100000,
                                                 hp.vocab, hp.batch_size,
                                                 shuffle=False)
    
    # create a iterator of the correct shape and type
    iter = tf.data.Iterator.from_structure(train_batches.output_types, train_batches.output_shapes)
    xs, ys = iter.get_next()
    
    train_init_op = iter.make_initializer(train_batches)
    eval_init_op = iter.make_initializer(eval_batches)
    
    logging.info("# Load model")
    m = Transformer(hp)
    loss, train_op, global_step, train_summaries = m.train(xs, ys)
    y_hat, eval_summaries = m.eval(xs, ys)
    # y_hat = m.infer(xs, ys)
    
    logging.info("# Session")
    saver = tf.train.Saver(max_to_keep=hp.num_epochs)
    with tf.Session() as sess:
        ckpt = tf.train.latest_checkpoint(hp.logdir)
        if ckpt is None:
            logging.info("Initializing from scratch")
            sess.run(tf.global_variables_initializer())
            save_variable_specs(os.path.join(hp.logdir, "specs"))
        else:
            saver.restore(sess, ckpt)
    
        summary_writer = tf.summary.FileWriter(hp.logdir, sess.graph)
    
        sess.run(train_init_op)
        total_steps = hp.num_epochs * num_train_batches
        _gs = sess.run(global_step)
        for i in tqdm(range(_gs, total_steps+1)):
            _, _gs, _summary = sess.run([train_op, global_step, train_summaries])
            epoch = math.ceil(_gs / num_train_batches)
            summary_writer.add_summary(_summary, _gs)
    
            if _gs and _gs % num_train_batches == 0:
                logging.info("epoch {} is done".format(epoch))
                _loss = sess.run(loss) # train loss
    
                logging.info("# test evaluation")
                _, _eval_summaries = sess.run([eval_init_op, eval_summaries])
                summary_writer.add_summary(_eval_summaries, _gs)
    
                logging.info("# get hypotheses")
                hypotheses = get_hypotheses(num_eval_batches, num_eval_samples, sess, y_hat, m.idx2token)
    
                logging.info("# write results")
                model_output = "iwslt2016_E%02dL%.2f" % (epoch, _loss)
                if not os.path.exists(hp.evaldir): os.makedirs(hp.evaldir)
                translation = os.path.join(hp.evaldir, model_output)
                with open(translation, 'w') as fout:
                    fout.write("
    ".join(hypotheses))
    
                logging.info("# calc bleu score and append it to translation")
                calc_bleu(hp.eval3, translation)
    
                logging.info("# save models")
                ckpt_name = os.path.join(hp.logdir, model_output)
                saver.save(sess, ckpt_name, global_step=_gs)
                logging.info("after training of {} epochs, {} has been saved.".format(epoch, ckpt_name))
    
                logging.info("# fall back to train mode")
                sess.run(train_init_op)
        summary_writer.close()
    
    
    logging.info("Done")

    我们一行行来看:

    首先调用了hparams.py中的函数:

    import argparse
    
    class Hparams:
        parser = argparse.ArgumentParser()
    
        # prepro
        parser.add_argument('--vocab_size', default=32000, type=int)
    
        # train
        ## files
        parser.add_argument('--train1', default='iwslt2016/segmented/train.de.bpe',
                                 help="german training segmented data")
        parser.add_argument('--train2', default='iwslt2016/segmented/train.en.bpe',
                                 help="english training segmented data")
        parser.add_argument('--eval1', default='iwslt2016/segmented/eval.de.bpe',
                                 help="german evaluation segmented data")
        parser.add_argument('--eval2', default='iwslt2016/segmented/eval.en.bpe',
                                 help="english evaluation segmented data")
        parser.add_argument('--eval3', default='iwslt2016/prepro/eval.en',
                                 help="english evaluation unsegmented data")
    
        ## vocabulary
        parser.add_argument('--vocab', default='iwslt2016/segmented/bpe.vocab',
                            help="vocabulary file path")
    
        # training scheme
        parser.add_argument('--batch_size', default=128, type=int)
        parser.add_argument('--eval_batch_size', default=128, type=int)
    
        parser.add_argument('--lr', default=0.0003, type=float, help="learning rate")
        parser.add_argument('--warmup_steps', default=4000, type=int)
        parser.add_argument('--logdir', default="log/1", help="log directory")
        parser.add_argument('--num_epochs', default=20, type=int)
        parser.add_argument('--evaldir', default="eval/1", help="evaluation dir")
    
        # model
        parser.add_argument('--d_model', default=512, type=int,
                            help="hidden dimension of encoder/decoder")
        parser.add_argument('--d_ff', default=2048, type=int,
                            help="hidden dimension of feedforward layer")
        parser.add_argument('--num_blocks', default=6, type=int,
                            help="number of encoder/decoder blocks")
        parser.add_argument('--num_heads', default=8, type=int,
                            help="number of attention heads")
        parser.add_argument('--maxlen1', default=100, type=int,
                            help="maximum length of a source sequence")
        parser.add_argument('--maxlen2', default=100, type=int,
                            help="maximum length of a target sequence")
        parser.add_argument('--dropout_rate', default=0.3, type=float)
        parser.add_argument('--smoothing', default=0.1, type=float,
                            help="label smoothing rate")
    
        # test
        parser.add_argument('--test1', default='iwslt2016/segmented/test.de.bpe',
                            help="german test segmented data")
        parser.add_argument('--test2', default='iwslt2016/prepro/test.en',
                            help="english test data")
        parser.add_argument('--ckpt', help="checkpoint file path")
        parser.add_argument('--test_batch_size', default=128, type=int)
        parser.add_argument('--testdir', default="test/1", help="test result dir")

    主要是一些超参数的设置。

    然后是data_load.py中用来加载数据集:

    # -*- coding: utf-8 -*-
    #/usr/bin/python3
    '''
    Feb. 2019 by kyubyong park.
    kbpark.linguist@gmail.com.
    https://www.github.com/kyubyong/transformer
    
    Note.
    if safe, entities on the source side have the prefix 1, and the target side 2, for convenience.
    For example, fpath1, fpath2 means source file path and target file path, respectively.
    '''
    import tensorflow as tf
    from utils import calc_num_batches
    
    def load_vocab(vocab_fpath):
        '''Loads vocabulary file and returns idx<->token maps
        vocab_fpath: string. vocabulary file path.
        Note that these are reserved
        0: <pad>, 1: <unk>, 2: <s>, 3: </s>
    
        Returns
        two dictionaries.
        '''
        vocab = [line.split()[0] for line in open(vocab_fpath, 'r').read().splitlines()]
        token2idx = {token: idx for idx, token in enumerate(vocab)}
        idx2token = {idx: token for idx, token in enumerate(vocab)}
        return token2idx, idx2token
    
    def load_data(fpath1, fpath2, maxlen1, maxlen2):
        '''Loads source and target data and filters out too lengthy samples.
        fpath1: source file path. string.
        fpath2: target file path. string.
        maxlen1: source sent maximum length. scalar.
        maxlen2: target sent maximum length. scalar.
    
        Returns
        sents1: list of source sents
        sents2: list of target sents
        '''
        sents1, sents2 = [], []
        with open(fpath1, 'r') as f1, open(fpath2, 'r') as f2:
            for sent1, sent2 in zip(f1, f2):
                if len(sent1.split()) + 1 > maxlen1: continue # 1: </s>
                if len(sent2.split()) + 1 > maxlen2: continue  # 1: </s>
                sents1.append(sent1.strip())
                sents2.append(sent2.strip())
        return sents1, sents2
    
    
    def encode(inp, type, dict):
        '''Converts string to number. Used for `generator_fn`.
        inp: 1d byte array.
        type: "x" (source side) or "y" (target side)
        dict: token2idx dictionary
    
        Returns
        list of numbers
        '''
        inp_str = inp.decode("utf-8")
        if type=="x": tokens = inp_str.split() + ["</s>"]
        else: tokens = ["<s>"] + inp_str.split() + ["</s>"]
    
        x = [dict.get(t, dict["<unk>"]) for t in tokens]
        return x
    
    def generator_fn(sents1, sents2, vocab_fpath):
        '''Generates training / evaluation data
        sents1: list of source sents
        sents2: list of target sents
        vocab_fpath: string. vocabulary file path.
    
        yields
        xs: tuple of
            x: list of source token ids in a sent
            x_seqlen: int. sequence length of x
            sent1: str. raw source (=input) sentence
        labels: tuple of
            decoder_input: decoder_input: list of encoded decoder inputs
            y: list of target token ids in a sent
            y_seqlen: int. sequence length of y
            sent2: str. target sentence
        '''
        token2idx, _ = load_vocab(vocab_fpath)
        for sent1, sent2 in zip(sents1, sents2):
            x = encode(sent1, "x", token2idx)
            y = encode(sent2, "y", token2idx)
            decoder_input, y = y[:-1], y[1:]
    
            x_seqlen, y_seqlen = len(x), len(y)
            yield (x, x_seqlen, sent1), (decoder_input, y, y_seqlen, sent2)
    
    def input_fn(sents1, sents2, vocab_fpath, batch_size, shuffle=False):
        '''Batchify data
        sents1: list of source sents
        sents2: list of target sents
        vocab_fpath: string. vocabulary file path.
        batch_size: scalar
        shuffle: boolean
    
        Returns
        xs: tuple of
            x: int32 tensor. (N, T1)
            x_seqlens: int32 tensor. (N,)
            sents1: str tensor. (N,)
        ys: tuple of
            decoder_input: int32 tensor. (N, T2)
            y: int32 tensor. (N, T2)
            y_seqlen: int32 tensor. (N, )
            sents2: str tensor. (N,)
        '''
        shapes = (([None], (), ()),
                  ([None], [None], (), ()))
        types = ((tf.int32, tf.int32, tf.string),
                 (tf.int32, tf.int32, tf.int32, tf.string))
        paddings = ((0, 0, ''),
                    (0, 0, 0, ''))
    
        dataset = tf.data.Dataset.from_generator(
            generator_fn,
            output_shapes=shapes,
            output_types=types,
            args=(sents1, sents2, vocab_fpath))  # <- arguments for generator_fn. converted to np string arrays
    
        if shuffle: # for training
            dataset = dataset.shuffle(128*batch_size)
    
        dataset = dataset.repeat()  # iterate forever
        dataset = dataset.padded_batch(batch_size, shapes, paddings).prefetch(1)
    
        return dataset
    
    def get_batch(fpath1, fpath2, maxlen1, maxlen2, vocab_fpath, batch_size, shuffle=False):
        '''Gets training / evaluation mini-batches
        fpath1: source file path. string.
        fpath2: target file path. string.
        maxlen1: source sent maximum length. scalar.
        maxlen2: target sent maximum length. scalar.
        vocab_fpath: string. vocabulary file path.
        batch_size: scalar
        shuffle: boolean
    
        Returns
        batches
        num_batches: number of mini-batches
        num_samples
        '''
        sents1, sents2 = load_data(fpath1, fpath2, maxlen1, maxlen2)
        batches = input_fn(sents1, sents2, vocab_fpath, batch_size, shuffle=shuffle)
        num_batches = calc_num_batches(len(sents1), batch_size)
        return batches, num_batches, len(sents1)

    6、看一下相关模型model.py

    # -*- coding: utf-8 -*-
    # /usr/bin/python3
    '''
    Feb. 2019 by kyubyong park.
    kbpark.linguist@gmail.com.
    https://www.github.com/kyubyong/transformer
    
    Transformer network
    '''
    import tensorflow as tf
    
    from data_load import load_vocab
    from modules import get_token_embeddings, ff, positional_encoding, multihead_attention, label_smoothing, noam_scheme
    from utils import convert_idx_to_token_tensor
    from tqdm import tqdm
    import logging
    
    logging.basicConfig(level=logging.INFO)
    
    class Transformer:
        '''
        xs: tuple of
            x: int32 tensor. (N, T1)
            x_seqlens: int32 tensor. (N,)
            sents1: str tensor. (N,)
        ys: tuple of
            decoder_input: int32 tensor. (N, T2)
            y: int32 tensor. (N, T2)
            y_seqlen: int32 tensor. (N, )
            sents2: str tensor. (N,)
        training: boolean.
        '''
        def __init__(self, hp):
            self.hp = hp
            self.token2idx, self.idx2token = load_vocab(hp.vocab)
            self.embeddings = get_token_embeddings(self.hp.vocab_size, self.hp.d_model, zero_pad=True)
    
        def encode(self, xs, training=True):
            '''
            Returns
            memory: encoder outputs. (N, T1, d_model)
            '''
            with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
                x, seqlens, sents1 = xs
    
                # src_masks
                src_masks = tf.math.equal(x, 0) # (N, T1)
    
                # embedding
                enc = tf.nn.embedding_lookup(self.embeddings, x) # (N, T1, d_model)
                enc *= self.hp.d_model**0.5 # scale
    
                enc += positional_encoding(enc, self.hp.maxlen1)
                enc = tf.layers.dropout(enc, self.hp.dropout_rate, training=training)
    
                ## Blocks
                for i in range(self.hp.num_blocks):
                    with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE):
                        # self-attention
                        enc = multihead_attention(queries=enc,
                                                  keys=enc,
                                                  values=enc,
                                                  key_masks=src_masks,
                                                  num_heads=self.hp.num_heads,
                                                  dropout_rate=self.hp.dropout_rate,
                                                  training=training,
                                                  causality=False)
                        # feed forward
                        enc = ff(enc, num_units=[self.hp.d_ff, self.hp.d_model])
            memory = enc
            return memory, sents1, src_masks
    
        def decode(self, ys, memory, src_masks, training=True):
            '''
            memory: encoder outputs. (N, T1, d_model)
            src_masks: (N, T1)
    
            Returns
            logits: (N, T2, V). float32.
            y_hat: (N, T2). int32
            y: (N, T2). int32
            sents2: (N,). string.
            '''
            with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
                decoder_inputs, y, seqlens, sents2 = ys
    
                # tgt_masks
                tgt_masks = tf.math.equal(decoder_inputs, 0)  # (N, T2)
    
                # embedding
                dec = tf.nn.embedding_lookup(self.embeddings, decoder_inputs)  # (N, T2, d_model)
                dec *= self.hp.d_model ** 0.5  # scale
    
                dec += positional_encoding(dec, self.hp.maxlen2)
                dec = tf.layers.dropout(dec, self.hp.dropout_rate, training=training)
    
                # Blocks
                for i in range(self.hp.num_blocks):
                    with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE):
                        # Masked self-attention (Note that causality is True at this time)
                        dec = multihead_attention(queries=dec,
                                                  keys=dec,
                                                  values=dec,
                                                  key_masks=tgt_masks,
                                                  num_heads=self.hp.num_heads,
                                                  dropout_rate=self.hp.dropout_rate,
                                                  training=training,
                                                  causality=True,
                                                  scope="self_attention")
    
                        # Vanilla attention
                        dec = multihead_attention(queries=dec,
                                                  keys=memory,
                                                  values=memory,
                                                  key_masks=src_masks,
                                                  num_heads=self.hp.num_heads,
                                                  dropout_rate=self.hp.dropout_rate,
                                                  training=training,
                                                  causality=False,
                                                  scope="vanilla_attention")
                        ### Feed Forward
                        dec = ff(dec, num_units=[self.hp.d_ff, self.hp.d_model])
    
            # Final linear projection (embedding weights are shared)
            weights = tf.transpose(self.embeddings) # (d_model, vocab_size)
            logits = tf.einsum('ntd,dk->ntk', dec, weights) # (N, T2, vocab_size)
            y_hat = tf.to_int32(tf.argmax(logits, axis=-1))
    
            return logits, y_hat, y, sents2
    
        def train(self, xs, ys):
            '''
            Returns
            loss: scalar.
            train_op: training operation
            global_step: scalar.
            summaries: training summary node
            '''
            # forward
            memory, sents1, src_masks = self.encode(xs)
            logits, preds, y, sents2 = self.decode(ys, memory, src_masks)
    
            # train scheme
            y_ = label_smoothing(tf.one_hot(y, depth=self.hp.vocab_size))
            ce = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=y_)
            nonpadding = tf.to_float(tf.not_equal(y, self.token2idx["<pad>"]))  # 0: <pad>
            loss = tf.reduce_sum(ce * nonpadding) / (tf.reduce_sum(nonpadding) + 1e-7)
    
            global_step = tf.train.get_or_create_global_step()
            lr = noam_scheme(self.hp.lr, global_step, self.hp.warmup_steps)
            optimizer = tf.train.AdamOptimizer(lr)
            train_op = optimizer.minimize(loss, global_step=global_step)
    
            tf.summary.scalar('lr', lr)
            tf.summary.scalar("loss", loss)
            tf.summary.scalar("global_step", global_step)
    
            summaries = tf.summary.merge_all()
    
            return loss, train_op, global_step, summaries
    
        def eval(self, xs, ys):
            '''Predicts autoregressively
            At inference, input ys is ignored.
            Returns
            y_hat: (N, T2)
            '''
            decoder_inputs, y, y_seqlen, sents2 = ys
    
            decoder_inputs = tf.ones((tf.shape(xs[0])[0], 1), tf.int32) * self.token2idx["<s>"]
            ys = (decoder_inputs, y, y_seqlen, sents2)
    
            memory, sents1, src_masks = self.encode(xs, False)
    
            logging.info("Inference graph is being built. Please be patient.")
            for _ in tqdm(range(self.hp.maxlen2)):
                logits, y_hat, y, sents2 = self.decode(ys, memory, src_masks, False)
                if tf.reduce_sum(y_hat, 1) == self.token2idx["<pad>"]: break
    
                _decoder_inputs = tf.concat((decoder_inputs, y_hat), 1)
                ys = (_decoder_inputs, y, y_seqlen, sents2)
    
            # monitor a random sample
            n = tf.random_uniform((), 0, tf.shape(y_hat)[0]-1, tf.int32)
            sent1 = sents1[n]
            pred = convert_idx_to_token_tensor(y_hat[n], self.idx2token)
            sent2 = sents2[n]
    
            tf.summary.text("sent1", sent1)
            tf.summary.text("pred", pred)
            tf.summary.text("sent2", sent2)
            summaries = tf.summary.merge_all()
    
            return y_hat, summaries
  • 相关阅读:
    Java基础知识强化97:final、finally、finally区别
    Java基础知识强化之集合框架笔记02:集合的继承体系图解
    Java基础知识强化之集合框架笔记01:集合的由来与数组的区别
    Java基础知识强化96:Calendar类之获取任意年份的2月有多少天的案例
    Java基础知识强化95:Calendar类之Calendar类的add()和set()方法
    Gym
    Gym
    Good Bye 2015 B. New Year and Old Property —— dfs 数学
    HDU1873 看病要排队 —— 优先队列(STL)
    HDU5877 Weak Pair dfs + 线段树/树状数组 + 离散化
  • 原文地址:https://www.cnblogs.com/xiximayou/p/13336729.html
Copyright © 2011-2022 走看看