zoukankan      html  css  js  c++  java
  • mxnet包含NDArray的列表更新

    Jul 26, 2017

    之前写的用来人工设定batch_sizeacc_xxx发现出现了问题。最终发现是列表更新的问题。
    想想之前的NDArray处理,也是奇葩了。比如,你能告诉我下面这段中注释的与非注释的,产生差别的原理么?...what?!居然会有差别?You kidding?

        def acc_update(self,normsize=1):
            assert self.binded and self.params_initialized and self.optimizer_initialized
    #        self._curr_module._exec_group.grad_arrays=
    #                      [[grad.copyto(grad.context)/normsize if grad is not None else None for grad in grads] for grads in self.grad]
            for acc_grads, mod_grads in zip(self.grad,self._curr_module._exec_group.grad_arrays):
                for acc_grad, mod_grad in zip(acc_grads, mod_grads):
                    if acc_grad is not None:
                        mod_grad = acc_grad.copyto(mod_grad.context)/normsize
                    else:
                        mod_grad = None
            ...
    

    Oct 22, 2017

    最近发现接口又改了((⊙﹏⊙)b),新版的(V0.11.1)里面这样做也不合适,用分片的方法可能是对的(从一些结果上来看,还不能肯定没问题)。

    Oct 23, 2017

    对比了累计更新和一次更新作为一个batch的输出,初步验证程序的正确性。


    两处的目的都很明显:想用self.grad的内容更新self._curr_module._exec_group.grad_arrays
    然而调试的结果是,没被注释掉的能够完成这项预期,另外一个不能(可能是暂时的)归纳出其规律,表现某种单一增长的特征。
    感觉上应该是列表之间的替换,但却没有这样运行
    后面再看问题出在哪?

    Sep 13, 2017

    没有发现可能的问题,打算先放一放了。开始的时候打算从_exec_group.grad_arays的接口入手,发现是从_exec_group._execs[].grad_array中传过来的,但在update的时候,用的是前者,猜测可能在更新前有过同步,在没有找到的情况下,直接将后者del或者赋值为None,但都没有效果,和昨天的情况保持了相同;此外,将上段程序中的normsize设置为0发现,更新后确实也没有显著变化(细微的变化应该是由weight decay引起的——仅在小数点变化),也就是说,acc_update中的self.grad确实起到了作用。所以被凌乱了...mess
    贴上两次的结果对比吧,以伺观code者得焉

    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    5873.4209
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    269556.56
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    1444888.8
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    4637960.0
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    11257292.0
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> 
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    22884572.0
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    41182624.0
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    

    上面这段应该是非正常结果的,下面这段是归为正常结果。

    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    5873.4209
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> 
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    269556.56
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> 
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    741699.12
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    33154.039
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    30.383278
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    >>> mod.forward(d)
    >>> #mod.get_outputs()[0].asnumpy().max()
    ... abs(mod.get_outputs()[2].asnumpy()[0,label_idx][:] - d.label[1].asnumpy()[0,label_idx][:]).sum()
    30.155006
    >>> #mod.backward()
    ... #mod.update()
    ... mod.acc_backward()
    >>> mod.acc_update()
    
  • 相关阅读:
    WCF Server Console
    Restart IIS With Powershell
    RestartService (recursively)
    Copy Files
    Stopping and Starting Dependent Services
    多线程同步控制 ManualResetEvent AutoResetEvent MSDN
    DTD 简介
    Using Powershell to Copy Files to Remote Computers
    Starting and Stopping Services (IIS 6.0)
    java中的NAN和INFINITY
  • 原文地址:https://www.cnblogs.com/chenyliang/p/7512347.html
Copyright © 2011-2022 走看看