  • PyTorch: with model = nn.Sequential(...), why does calling model(input) run forward directly?

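    This works because nn.Module, the base class of nn.Sequential, implements the __call__ method. Writing model(input) is therefore equivalent to calling model.__call__(input) (not model.call(input)), and __call__ in turn invokes forward.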
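    The dispatch itself can be sketched without PyTorch. The Module and Double classes below are hypothetical stand-ins, not the real torch.nn.Module; they only illustrate how Python routes an instance call through __call__ to forward.

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module, reduced to the dispatch step."""

    def __call__(self, *args, **kwargs):
        # Calling the instance lands here, which hands off to forward().
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class Double(Module):
    """Hypothetical subclass: only forward() is defined, as in a typical model."""

    def forward(self, x):
        return 2 * x


model = Double()
via_call = model(3)              # instance call, resolved through __call__
via_dunder = model.__call__(3)   # the same call written out explicitly
via_forward = model.forward(3)   # bypasses __call__ entirely
print(via_call, via_dunder, via_forward)
```

    All three calls return the same value here because this simplified __call__ does nothing except delegate to forward.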

    Below is the source of __call__ as it appeared in an older PyTorch release (current releases differ in detail):

      def __call__(self, *input, **kwargs):
          for hook in self._forward_pre_hooks.values():
              result = hook(self, input)
              if result is not None:
                  if not isinstance(result, tuple):
                      result = (result,)
                  input = result
          if torch._C._get_tracing_state():
              result = self._slow_forward(*input, **kwargs)
          else:
              result = self.forward(*input, **kwargs)
          for hook in self._forward_hooks.values():
              hook_result = hook(self, input, result)
              if hook_result is not None:
                  result = hook_result
          if len(self._backward_hooks) > 0:
              var = result
              while not isinstance(var, torch.Tensor):
                  if isinstance(var, dict):
                      var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                  else:
                      var = var[0]
              grad_fn = var.grad_fn
              if grad_fn is not None:
                  for hook in self._backward_hooks.values():
                      wrapper = functools.partial(hook, self)
                      functools.update_wrapper(wrapper, hook)
                      grad_fn.register_hook(wrapper)
          return result
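
    The listing above also shows why model(input) is preferred over model.forward(input): the hooks run only inside __call__. Here is a minimal pure-Python sketch of the first three stages (pre-hooks, forward, forward hooks). HookedModule and AddOne are hypothetical classes, tracing and backward hooks are left out, and the hooks are written straight into the dictionaries for brevity, whereas real PyTorch code would register them through register_forward_pre_hook and register_forward_hook.

```python
from collections import OrderedDict


class HookedModule:
    """Hypothetical, simplified imitation of the call pipeline in the listing."""

    def __init__(self):
        self._forward_pre_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()

    def __call__(self, *inputs, **kwargs):
        # 1. Pre-hooks may replace the positional inputs.
        for hook in self._forward_pre_hooks.values():
            result = hook(self, inputs)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                inputs = result
        # 2. The actual computation.
        result = self.forward(*inputs, **kwargs)
        # 3. Forward hooks may replace the output.
        for hook in self._forward_hooks.values():
            hook_result = hook(self, inputs, result)
            if hook_result is not None:
                result = hook_result
        return result


class AddOne(HookedModule):
    """Hypothetical module whose forward() adds one to its input."""

    def forward(self, x):
        return x + 1


model = AddOne()
model._forward_pre_hooks["scale"] = lambda module, inputs: inputs[0] * 10
model._forward_hooks["negate"] = lambda module, inputs, output: -output

with_hooks = model(2)             # pre-hook: 2 -> 20, forward: 21, hook: -21
without_hooks = model.forward(2)  # hooks are skipped, so the result is 3
print(with_hooks, without_hooks)
```

    Calling forward directly returns a different result because neither hook runs.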
    
  • Original article: https://www.cnblogs.com/ailitao/p/11787540.html