zoukankan      html  css  js  c++  java
  • MXNet中bucket机制注记

    Preface

    之前看API以为bucket是一个根植于底层操作的接口(MXNet doc功不可没 -_-|| )。从LSTM看过来,接触到了一些相关的程序,后面再把bucketing_module.py那部分查看了下,发现bucket只是一个应用层机制,主要的实现存在于module/bucketing_module.py里面。原理清晰,实现简洁,在这做个记号。

    Code & Comments

    先放些相关的链接,做个预备。

    1. MXNet 官方的文档( ucao 出个文档真不容易,还带时效性...)
    2. 大神的blog阐述,鞭辟入里
    3. 之前关于LSTM的blog
      鉴于大神已经在这篇[blog]里面说得生动透彻了,这里就能省就省,然后说些大神没功夫顾及的细节。
      另外考虑到MXNet的链接经常表现出不靠谱的症状(kuxia),归结一下1中有些用的结论:要使用bucket机制,初始化Module时传入的symbol应该是一个函数,这个函数在被调用时将被传入迭代器中的bucket_key参数

    从调用路径的顺序来走一遍把。
    fit里面经过bind,init等操作,后面会调用prepare对预取出的数据(如果有)进行准备:

    # module/bucketing_module.py
        def prepare(self, data_batch):
            """Prepares a data batch for forward.
    
            Parameters
            ----------
            data_batch : DataBatch
            """
            # perform bind if haven't done so
            assert self.binded and self.params_initialized
            bucket_key = data_batch.bucket_key
            original_bucket_key = self._curr_bucket_key
            data_shapes = data_batch.provide_data
            label_shapes = data_batch.provide_label
            self.switch_bucket(bucket_key, data_shapes, label_shapes)
            # switch back
            self.switch_bucket(original_bucket_key, None, None)
    

    显然,switch_bucket就是负责进行重新绑定的:

    # module/bucketing_module.py
        def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
             assert self.binded, 'call bind before switching bucket'
            if not bucket_key in self._buckets:    # check if there is already...
                symbol, data_names, label_names = self._sym_gen(bucket_key)
                module = Module(symbol, data_names, label_names,
                                logger=self.logger, context=self._context,
                                work_load_list=self._work_load_list,
                                fixed_param_names=self._fixed_param_names,
                                state_names=self._state_names)
                module.bind(data_shapes, label_shapes, self._curr_module.for_training,
                            self._curr_module.inputs_need_grad,
                            force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
                self._buckets[bucket_key] = module
    
            self._curr_module = self._buckets[bucket_key]
            self._curr_bucket_key = bucket_key
    

    逻辑很明白,_curr_module里面放了众多的module,这些module的参数全都指向同一组。如果出入的bucket_key没有出现过,就bind一个并放入_curr_module列表里面去;如果已经有了(包括刚刚bind出来的),就切换到那个module上。

    Misc

    其他有一些相关的材料顺带放在这。

    1. 上一篇blog里面推测bucket机制可能会对补齐的那部分进行处理,这一点与io.py里面的DataBatchpad变量有些联系。在module/base_module.py中,查找pad的引用,发现和io.py里面的注释一致,只在prediction的时候进行了使用,训练的时候被忽视。
    2. exmple/rnn/bucketing里面有更高层接口的使用示例。
  • 相关阅读:
    后台数值往前台传值,能获取到值,页面显示不出来的问题
    总结jquery中对select和option的基本操作
    使用<input type="image" src="...">标签会引发页面刷新的问题
    用java来实现验证码功能。
    用java来实现验证码功能(本帖为转载贴),作为个人学习收藏用
    使用java实现发送邮件的功能
    java中的中文参数存到数据库乱码问题
    模糊查询时,页面没有数据,数据库编辑器里可以正常显示数据
    c# 无法加载 DLL xxxxxxxx找不到指定的模块。 (异常来自HRESULT:0x8007007E)。的一个解决方法
    关于C#调用C++ 的DLL传送字符串显示乱码的解决
  • 原文地址:https://www.cnblogs.com/chenyliang/p/8060014.html
Copyright © 2011-2022 走看看