zoukankan      html  css  js  c++  java
  • [The Annotated Transformer] Iterators

    Iterators

    对torchtext的batch实现的修改算法原理

    Batching matters a ton for speed. We want to have very evenly divided batches, with absolutely minimal padding. To do this we have to hack a bit around the default torchtext batching. This code patches their default batching to make sure we search over enough sentences to find tight batches.

    这里是对torchtext中默认的batching操作进行的优化修改。

    参考:https://towardsdatascience.com/how-to-use-torchtext-for-neural-machine-translation-plus-hack-to-make-it-5x-faster-77f3884d95

    Torchtext本身已经很好了,并且sort_key使得dataset中的数据排序,这样batching后序列长度相近的会被放在同一个batch中,可以很大程度上降低padding的个数。

    但是下面代码又进行了优化:根据每个batch中序列的最大长度,动态更改batch_size,使得可以更好的利用计算资源。

    举个例子:

    假设你的RAM每个iteration可以处理1500个tokens, batch_size = 20, 那么只有当batch中的序列长度为sequence length = 1500 / 20 = 75时,才可以将计算资源利用完全。

    现实中,每个batch的sequence length的显然是在变化的,那么如果希望尽量多的利用计算资源,就需要可以动态调整当前的batch_size.

    Transformer中的MyIterator重载了data.Iterator中的create_batches函数:

     1 class MyIterator(data.Iterator):
     2     def create_batches(self):
     3         if self.train:
     4             def pool(d, random_shuffler):
     5                 for p in data.batch(d, self.batch_size * 100):
     6                     p_batch = data.batch(
     7                         sorted(p, key=self.sort_key),
     8                         self.batch_size, self.batch_size_fn)
     9                     for b in random_shuffler(list(p_batch)):
    10                         yield b
    11             self.batches = pool(self.data(), self.random_shuffler)
    12             
    13         else:
    14             self.batches = []
    15             for b in data.batch(self.data(), self.batch_size,
    16                                           self.batch_size_fn):
    17                 self.batches.append(sorted(b, key=self.sort_key))
    18 
    19 def rebatch(pad_idx, batch):
    20     "Fix order in torchtext to match ours"
    21     src, trg = batch.src.transpose(0, 1), batch.trg.transpose(0, 1)
    22     return Batch(src, trg, pad_idx)

    pool函数

    其中pool函数的功能与https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py中定义的class BucketIterator(Iterator)的pool函数功能类似。

    1. 将原始的data分成大小为 100 * batch_size的一些chunks => (以上迭代 p 即为 每个chunk)

    2. 在每个chunk中根据 sort_key 对examples进行排序,并对每个chunk按照batch_size分成100个batch =>

    p_batch = data.batch( sorted(p, key=self.sort_key), self.batch_size, self.batch_size_fn) )

    3. 将这些chunks进行shuffle  => (random_shuffler(list(p_batch)))

    4. 在每个chunk中再把examples分成 大小为 batch_size 的 100 个 batch => (以上 b 即为每个 batch)

    5. 生成器每次 yield一个batch  => (yield b)

  • 相关阅读:
    MySQL的四种事务隔离级别
    线上CPU飚高(死循环,死锁...)
    Tomcat7 调优及 JVM 参数优化
    tomcat8.5配置高并发
    Tomcat 8.5 基于 Apache Portable Runtime(APR)库性能优化
    android 高德地图 轨迹平滑处理
    android高德地图绘制线 渐变处理
    按下home键,重新打开,应用重启
    小米9屏下指纹判断
    android 9.0以上charles https抓包
  • 原文地址:https://www.cnblogs.com/shiyublog/p/10919988.html
Copyright © 2011-2022 走看看