  • PyTorch Hook Functions

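    PyTorch ships with hook functions. The word "hook" put me off at first: I did not know the term, its literal Chinese translation explained nothing, and I had never needed it while assembling models from library building blocks. Only after reading the following article did I gain an initial understanding of hooks: https://towardsdatascience.com/the-one-pytorch-trick-which-you-should-know-2d5e9c1da2ca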

    1. Why hook functions are needed

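    • When a neural network has a bug and does not produce the output we expect, we need to debug it. The usual approach is to add print calls inside forward to show the outputs of certain layers, or to set breakpoints and single-step in order to inspect intermediate outputs. In PyTorch the same thing can be done with hook functions.
    • Because of PyTorch's autograd mechanism, once a tensor has requires_grad=True, the operations involving it are recorded by autograd so that gradients can be computed in the backward pass. However, autograd only keeps the gradients of leaf nodes; the gradients of intermediate variables are freed automatically once they have been used, to save memory. For example: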
    import torch

    x = torch.tensor([1,2],dtype=torch.float32,requires_grad=True)
    y = x * 2
    z = torch.mean(y)
    z.backward()
    print("x.grad =", x.grad)
    print("y.grad =", y.grad)
    print("z.grad =", z.grad)
    

    Output

    x.grad = tensor([1., 1.])
    y.grad = None
    z.grad = None
    

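    So if we want to know the gradients of y and z, we need a hook function.
    In other words, hooks let us get at intermediate values that are otherwise inconvenient to obtain.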
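As a minimal sketch of that idea, the snippet below repeats the example above and attaches a hook to each non-leaf tensor with `Tensor.register_hook`. The `captured` dictionary and the `save_grad` helper are names invented for this illustration, not part of PyTorch.

```python
import torch

# Holds the gradients of non-leaf tensors (name chosen for this sketch).
captured = {}

def save_grad(name):
    # Build a hook that stores the incoming gradient under `name`.
    def hook(grad):
        captured[name] = grad.clone()
    return hook

x = torch.tensor([1, 2], dtype=torch.float32, requires_grad=True)
y = x * 2
z = torch.mean(y)

# Hooks must be registered before backward() is called.
y.register_hook(save_grad("y"))
z.register_hook(save_grad("z"))
z.backward()

print("x.grad =", x.grad)         # tensor([1., 1.])
print("y grad =", captured["y"])  # tensor([0.5000, 0.5000])
print("z grad =", captured["z"])  # tensor(1.)
```

The hooks do not change the computation here; they only copy the gradient as it flows past each tensor.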

    2. What is a hook function

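    • A hook is just an ordinary function or class; more precisely, it is a callable object. We write whatever functionality we need into it, and in that sense it is no different from any other function or class we write. What PyTorch adds is a mechanism for registering such a callable on a layer (nn.Module): when that layer runs forward or backward, its inputs or outputs are automatically passed to our hook, which is then executed. The function therefore behaves like a hook that can be hung on a layer and taken off again, which is where the name comes from.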
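The hang-on and take-off cycle described above can be sketched as follows. The `log_shapes` function and the `calls` list are hypothetical names for this illustration; `register_forward_hook` and `handle.remove()` are the PyTorch calls discussed in this article.

```python
import torch
from torch import nn

calls = []  # one entry per hook invocation (name chosen for this sketch)

def log_shapes(module, inputs, output):
    # `inputs` is a tuple holding the positional arguments given to the module.
    calls.append((type(module).__name__, tuple(inputs[0].shape), tuple(output.shape)))

layer = nn.Linear(4, 2)
handle = layer.register_forward_hook(log_shapes)  # hang the hook on the layer

layer(torch.randn(3, 4))  # the hook fires
handle.remove()           # take the hook off again
layer(torch.randn(3, 4))  # the hook no longer fires

print(calls)  # [('Linear', (3, 4), (3, 2))]
```

Keeping the returned handle is what makes the hook removable later, so it is usually worth storing even in throwaway debugging code.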

    3. Hooks provided by PyTorch

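    • Generally, there are three things we want to know when debugging:
      • what a module's input is, i.e. the input just before forward runs
      • what a module's output is, i.e. the output just after forward runs
      • what a module's gradients are after backpropagation, i.e. the state of the module after backward runs
    • Comparing these three pieces of data with what we expect tells us where the problem lies. PyTorch provides exactly these three kinds of hooks: attach them to a chosen layer, and that layer's inputs and outputs are passed as arguments to the hook function when it runs.
      [Figure: image.png, quoted from an external source]
    • The nn.Module source code in PyTorch provides these three attributes: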
            self._backward_hooks = OrderedDict()
            self._forward_hooks = OrderedDict()
            self._forward_pre_hooks = OrderedDict()
    
    • It also provides three registration methods, which are what fill the three dicts above:
      • forward prehook (executing before the forward pass),
      • forward hook (executing after the forward pass),
      • backward hook (executing after the backward pass).
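To make the link between the registration methods and the dictionaries concrete, here is a small sketch that peeks at `_forward_hooks` before and after registration. It reads a private attribute purely for illustration; the variable names are invented for this example.

```python
from torch import nn

layer = nn.ReLU()
before = len(layer._forward_hooks)  # nothing registered yet

# Register a do-nothing forward hook.
handle = layer.register_forward_hook(lambda module, inputs, output: None)
during = len(layer._forward_hooks)
stored_key = next(iter(layer._forward_hooks))  # the dict key is the handle's id

handle.remove()  # removing the handle deletes the dict entry
after = len(layer._forward_hooks)

print(before, during, after, stored_key == handle.id)  # 0 1 0 True
```

In real code there is no need to touch these dictionaries directly; the registration methods and the returned handle are the supported way to manage hooks.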

    register_forward_pre_hook: runs before forward and receives the module's input

        def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:
            r"""Registers a forward pre-hook on the module.
    
            The hook will be called every time before :func:`forward` is invoked.
            It should have the following signature::
    
                hook(module, input) -> None or modified input
    
            The input contains only the positional arguments given to the module.
            Keyword arguments won't be passed to the hooks and only to the ``forward``.
            The hook can modify the input. User can either return a tuple or a
            single modified value in the hook. We will wrap the value into a tuple
            if a single value is returned(unless that value is already a tuple).
    
            Returns:
                :class:`torch.utils.hooks.RemovableHandle`:
                    a handle that can be used to remove the added hook by calling
                    ``handle.remove()``
            """
            handle = hooks.RemovableHandle(self._forward_pre_hooks)
            self._forward_pre_hooks[handle.id] = hook
            return handle
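
As the docstring above says, a forward pre-hook may return a modified input. The following sketch, with a hypothetical `scale_input` hook on an `nn.Identity` layer, shows the input being doubled before forward sees it.

```python
import torch
from torch import nn

def scale_input(module, inputs):
    # `inputs` is the tuple of positional arguments; return the replacement as a tuple.
    (x,) = inputs
    return (x * 2,)

layer = nn.Identity()
handle = layer.register_forward_pre_hook(scale_input)

x = torch.ones(2, 3)
out_hooked = layer(x)  # forward receives x * 2
handle.remove()
out_plain = layer(x)   # forward receives x unchanged

print(out_hooked)
print(out_plain)
```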
    

    register_forward_hook: runs after forward and receives the module's input and output

        def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle:
            r"""Registers a forward hook on the module.
    
            The hook will be called every time after :func:`forward` has computed an output.
            It should have the following signature::
    
                hook(module, input, output) -> None or modified output
    
            The input contains only the positional arguments given to the module.
            Keyword arguments won't be passed to the hooks and only to the ``forward``.
            The hook can modify the output. It can modify the input inplace but
            it will not have effect on forward since this is called after
            :func:`forward` is called.
    
            Returns:
                :class:`torch.utils.hooks.RemovableHandle`:
                    a handle that can be used to remove the added hook by calling
                    ``handle.remove()``
            """
            handle = hooks.RemovableHandle(self._forward_hooks)
            self._forward_hooks[handle.id] = hook
            return handle
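
A forward hook can likewise replace the output by returning a value. The sketch below uses a hypothetical `clamp_output` hook that clips negative activations of an `nn.Linear` layer; the layer sizes are arbitrary.

```python
import torch
from torch import nn

def clamp_output(module, inputs, output):
    # Returning a tensor replaces the module's output; returning None keeps it.
    return output.clamp(min=0.0)

layer = nn.Linear(3, 3)
handle = layer.register_forward_hook(clamp_output)

out = layer(torch.randn(5, 3))  # the caller receives the clamped tensor
handle.remove()

print(out.min().item() >= 0.0)
```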
    

    register_backward_hook: receives the module's grad_in and grad_out during backpropagation

        def register_backward_hook(
            self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]]
        ) -> RemovableHandle:
            r"""Registers a backward hook on the module.
    
            This function is deprecated in favor of :meth:`nn.Module.register_full_backward_hook` and
            the behavior of this function will change in future versions.
    
            Returns:
                :class:`torch.utils.hooks.RemovableHandle`:
                    a handle that can be used to remove the added hook by calling
                    ``handle.remove()``
    
            """
            if self._is_full_backward_hook is True:
                raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a "
                                   "single Module. Please use only one of them.")
    
            self._is_full_backward_hook = False
    
            handle = hooks.RemovableHandle(self._backward_hooks)
            self._backward_hooks[handle.id] = hook
            return handle
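
Since the docstring above marks `register_backward_hook` as deprecated in favor of `register_full_backward_hook`, here is a small sketch using the newer method (it requires a PyTorch version that provides it). The `record_grads` hook and the `shapes` dictionary are names made up for this example.

```python
import torch
from torch import nn

shapes = {}  # filled in by the hook (name chosen for this sketch)

def record_grads(module, grad_input, grad_output):
    # Both arguments are tuples; entries can be None for inputs that need no gradient.
    shapes["grad_input"] = [None if g is None else tuple(g.shape) for g in grad_input]
    shapes["grad_output"] = [None if g is None else tuple(g.shape) for g in grad_output]

layer = nn.Linear(4, 2)
handle = layer.register_full_backward_hook(record_grads)

x = torch.randn(3, 4, requires_grad=True)
layer(x).sum().backward()  # the hook fires during the backward pass
handle.remove()

print(shapes)  # {'grad_input': [(3, 4)], 'grad_output': [(3, 2)]}
```

Here `grad_output` is the gradient with respect to the layer's output and `grad_input` the gradient with respect to its input.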
    

    4. A hook example

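    Here we attach a hook to every convolutional (Conv2d) layer of ResNet34 in order to capture each layer's output, so we use register_forward_hook.
    The image below is used as the input.
    [Figure: image.png]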

    import torch
    from torchvision.models import resnet34
    
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = resnet34(pretrained=True)
    model = model.to(device)
    model.eval()  # inference mode, so BatchNorm uses its stored running statistics
    
    class SaveOutput:
        def __init__(self):
            self.outputs = []
            self.inputs = []
            
        def __call__(self, module, module_in, module_out):
            print(module)
            self.inputs.append(module_in)
            self.outputs.append(module_out)
            
        def clear(self):
            self.outputs = []
            self.inputs = []
            
    
    save_output = SaveOutput()
    
    hook_handles = []
    
    for layer in model.modules():
        if isinstance(layer, torch.nn.modules.conv.Conv2d):
            handle = layer.register_forward_hook(save_output)
            hook_handles.append(handle)
            
            
    from PIL import Image
    from torchvision import transforms as T
    
    img = Image.open('./cat.jpeg')
    transform = T.Compose([T.Resize((224,224)),
                           T.ToTensor(),
                           T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet mean and std
                          ])
    x = transform(img).unsqueeze(0).to(device)
    out = model(x)
    

    Output

    Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
    Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
    Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
    Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

    > save_output.outputs[0].size()
    torch.Size([1, 64, 112, 112])
    > save_output.inputs[0][0].size()
    torch.Size([1, 3, 224, 224])

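    As we can see, the module and its input and output are automatically passed as arguments to our SaveOutput instance, which is then called.
    Below is a visualization of each layer's output.
    [Figure: image.png]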
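
A variation on the example above, sketched on a small stand-in network so that no pretrained weights are downloaded: outputs are stored by layer name using `named_modules`, and the hooks are removed afterwards. The network, the `features` dictionary, and `make_hook` are assumptions for this illustration, not part of the original example.

```python
import torch
from torch import nn

# Small stand-in network (hypothetical; used instead of ResNet34).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
)
model.eval()

features = {}  # layer name -> captured output
handles = []

def make_hook(name):
    def hook(module, inputs, output):
        features[name] = output.detach()  # detach so no autograd graph is kept alive
    return hook

for name, layer in model.named_modules():
    if isinstance(layer, nn.Conv2d):
        handles.append(layer.register_forward_hook(make_hook(name)))

with torch.no_grad():
    model(torch.randn(1, 3, 32, 32))

for handle in handles:
    handle.remove()  # leave the model clean once the features are captured

for name, feat in features.items():
    print(name, tuple(feat.shape))
```

Keying by name makes it easier to match a captured tensor to a layer than indexing into a flat list, at the cost of one closure per layer.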


    Hooks on a Tensor

    x = torch.tensor([1,2],dtype=torch.float32,requires_grad=True)
    y = x * 2
    y.register_hook(print)
    z = torch.mean(y)
    z.backward()
    

    Output:

    tensor([0.5000, 0.5000])
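
A tensor hook can do more than print. If it returns a tensor, that value is used in place of the original gradient. The sketch below repeats the example with a hook that scales the gradient by 10; the scaling factor is arbitrary.

```python
import torch

x = torch.tensor([1, 2], dtype=torch.float32, requires_grad=True)
y = x * 2

# The hook's return value replaces the gradient flowing back through y.
handle = y.register_hook(lambda grad: grad * 10)

z = torch.mean(y)
z.backward()
handle.remove()

# Without the hook x.grad would be [1., 1.]; the hook turns 0.5 into 5.0 before the factor of 2.
print(x.grad)  # tensor([10., 10.])
```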
    

    Hooks applied to model pruning
    https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
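
As a brief sketch of how the pruning utilities in that tutorial rely on hooks: `torch.nn.utils.prune` installs a forward pre-hook that recomputes the masked weight before every forward pass. The layer size and pruning amount below are arbitrary, and the private `_forward_pre_hooks` dict is read only for illustration.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 3)
prune.l1_unstructured(layer, name="weight", amount=0.5)

# After pruning, `weight_orig` is the trainable parameter and `weight_mask` is a buffer;
# a forward pre-hook sets `weight = weight_orig * weight_mask` before each forward call.
num_pre_hooks = len(layer._forward_pre_hooks)
num_zeros = int((layer.weight == 0).sum())

print(num_pre_hooks, num_zeros)
print(hasattr(layer, "weight_orig"), hasattr(layer, "weight_mask"))
```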

  • Original article: https://www.cnblogs.com/qiulinzhang/p/14293662.html