zoukankan      html  css  js  c++  java
  • MXNet中bucket机制注记

    Preface

    之前看API以为bucket是一个根植于底层操作的接口(MXNet doc功不可没 -_-|| )。从LSTM看过来,接触到了一些相关的程序,后面再把bucketing_module.py那部分查看了下,发现bucket只是一个应用层机制,主要的实现存在于module/bucketing_module.py里面。原理清晰,实现简洁,在这做个记号。

    Code & Comments

    先放些相关的链接,做个预备。

    1. MXNet 官方的文档( ucao 出个文档真不容易,还带时效性...)
    2. 大神的blog阐述,鞭辟入里
    3. 之前关于LSTM的blog
      鉴于大神已经在这篇[blog]里面说得生动透彻了,这里就能省就省,然后说些大神没功夫顾及的细节。
      另外考虑到MXNet的链接经常表现出不靠谱的症状(kuxia),归结一下1中有些用的结论:要使用bucket机制,初始化Module时传入的symbol应该是一个函数,这个函数在被调用时将被传入迭代器中的bucket_key参数

    从调用路径的顺序来走一遍把。
    fit里面经过bind,init等操作,后面会调用prepare对预取出的数据(如果有)进行准备:

    # module/bucketing_module.py
        def prepare(self, data_batch):
            """Prepares a data batch for forward.
    
            Parameters
            ----------
            data_batch : DataBatch
            """
            # perform bind if haven't done so
            assert self.binded and self.params_initialized
            bucket_key = data_batch.bucket_key
            original_bucket_key = self._curr_bucket_key
            data_shapes = data_batch.provide_data
            label_shapes = data_batch.provide_label
            self.switch_bucket(bucket_key, data_shapes, label_shapes)
            # switch back
            self.switch_bucket(original_bucket_key, None, None)
    

    显然,switch_bucket就是负责进行重新绑定的:

    # module/bucketing_module.py
        def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
             assert self.binded, 'call bind before switching bucket'
            if not bucket_key in self._buckets:    # check if there is already...
                symbol, data_names, label_names = self._sym_gen(bucket_key)
                module = Module(symbol, data_names, label_names,
                                logger=self.logger, context=self._context,
                                work_load_list=self._work_load_list,
                                fixed_param_names=self._fixed_param_names,
                                state_names=self._state_names)
                module.bind(data_shapes, label_shapes, self._curr_module.for_training,
                            self._curr_module.inputs_need_grad,
                            force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
                self._buckets[bucket_key] = module
    
            self._curr_module = self._buckets[bucket_key]
            self._curr_bucket_key = bucket_key
    

    逻辑很明白,_curr_module里面放了众多的module,这些module的参数全都指向同一组。如果出入的bucket_key没有出现过,就bind一个并放入_curr_module列表里面去;如果已经有了(包括刚刚bind出来的),就切换到那个module上。

    Misc

    其他有一些相关的材料顺带放在这。

    1. 上一篇blog里面推测bucket机制可能会对补齐的那部分进行处理,这一点与io.py里面的DataBatchpad变量有些联系。在module/base_module.py中,查找pad的引用,发现和io.py里面的注释一致,只在prediction的时候进行了使用,训练的时候被忽视。
    2. exmple/rnn/bucketing里面有更高层接口的使用示例。
  • 相关阅读:
    springMVC源码学习地址
    JVM架构和GC垃圾回收机制详解
    String StringBuffer和StringBuilder区别及性能
    java reflect反射获取方法变量参数
    springMVC数据模型model,modelmap,map,@ModelAttribute的相互关系
    java abstract构造函数调用
    springMVC源码学习之addFlashAttribute源码分析
    LeetCode 404. Sum of Left Leaves
    利用JavaFX访问MySQL数据库
    LeetCode 111. Minimum Depth of Binary Tree
  • 原文地址:https://www.cnblogs.com/chenyliang/p/8060014.html
Copyright © 2011-2022 走看看