zoukankan      html  css  js  c++  java
  • [torch] pytorch hook学习

    pytorch hook学习

    register_hook

    import torch
    x = torch.Tensor([0,1,2,3]).requires_grad_()
    y = torch.Tensor([4,5,6,7]).requires_grad_()
    w = torch.Tensor([1,2,3,4]).requires_grad_()
    z = x+y;
    o = w.matmul(z) # o = w(x+y) 中间变量z
    o.backward()
    print(x.grad,y.grad,z.grad,w.grad,o.grad)
    

    这里的o和z都是中间变量,不是通过指定值来定义的变量,所以是中间变量,所以pytorch并不存储这些变量的梯度。

    对于中间变量z,hook的使用方式为: z.register_hook(hook_fn),其中 hook_fn为一个用户自定义的函数,其签名为:hook_fn(grad) -> Tensor or None。

    它的输入为变量 z 的梯度,输出为一个 Tensor 或者是 None (None 一般用于直接打印梯度)。反向传播时,梯度传播到变量 z,再继续向前传播之前,将会传入 hook_fn。如果 hook_fn的返回值是 None,那么梯度将不改变,继续向前传播,如果 hook_fn的返回值是 Tensor 类型,则该 Tensor 将取代 z 原有的梯度,向前传播。

    import torch
    x = torch.Tensor([0,1,2,3]).requires_grad_()
    y = torch.Tensor([4,5,6,7]).requires_grad_()
    w = torch.Tensor([1,2,3,4]).requires_grad_()
    z = x+y;
    def hook_fn(grad):
        print(grad)
        return None
    
    z.register_hook(hook_fn)
    o = w.matmul(z) # o = w(x+y) 中间变量z
    o.backward()
    print(x.grad,y.grad,w.grad,z.grad,o.grad)
    
    

    register_forward_hook

    register_forward_hook的作用是获取前向传播过程中,各个网络模块的输入和输出。对于模块 module,其使用方式为:module.register_forward_hook(hook_fn) 。其中 hook_fn的签名为:

    hook_fn(module, input, output) -> None
    

    eg

    import torch
    from torch import nn
    class Model(nn.Module):
        def __init__(self):
            super(Model,self).__init__()
            self.fc1 = nn.Linear(3,4) # WT * X + bias
            self.relu1 = nn.ReLU()
            self.fc2 = nn.Linear(4,1)
            self.init()
        def init(self):
            with torch.no_grad():
                # WT * X + bias,所以W为4*3的矩阵,bias为1*4
                self.fc1.weight = torch.nn.Parameter(
                    torch.Tensor([[1., 2., 3.],
                                  [-4., -5., -6.],
                                  [7., 8., 9.],
                                  [-10., -11., -12.]]))
                self.fc1.bias = torch.nn.Parameter(torch.Tensor([1.0, 2.0, 3.0, 4.0]))
                self.fc2.weight = torch.nn.Parameter(torch.Tensor([[1.0, 2.0, 3.0, 4.0]]))
                self.fc2.bias = torch.nn.Parameter(torch.Tensor([1.0]))
    
        def forward(self,x):
            o = self.fc1(x)
            o = self.relu1(o)
            o = self.fc2(o)
            return o
    def hook_fn_forward(module,input,output):
        print(module)
        print(input)
        print(output)
    
    
    model = Model()
    modules = model.named_children()
    '''
    named_children()
    Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
    '''
    for name,module in modules:
        # 这里的name就是自己定义的self.xx的xx。如上面的fc1,fc2.
        # module代指的就是fc1代表的module等等
        module.register_forward_hook(hook_fn_forward)
    x = torch.Tensor([[1.0,1.0,1.0]]).requires_grad_()
    o = model(x)
    o.backward()
     '''
     Linear(in_features=3, out_features=4, bias=True)
    (tensor([[1., 1., 1.]], requires_grad=True),)
    tensor([[  7., -13.,  27., -29.]], grad_fn=<AddmmBackward>)
    ReLU()
    (tensor([[  7., -13.,  27., -29.]], grad_fn=<AddmmBackward>),)
    tensor([[ 7.,  0., 27.,  0.]], grad_fn=<ReluBackward0>)
    Linear(in_features=4, out_features=1, bias=True)
    (tensor([[ 7.,  0., 27.,  0.]], grad_fn=<ReluBackward0>),)
    tensor([[89.]], grad_fn=<AddmmBackward>)
     
     '''
    

    register_backward_hook

    理同前者。得到梯度值。

    hook_fn(module, grad_input, grad_output) -> Tensor or None
    

    上面的代码forward全部替换为backward,结果为:

    '''
    Linear(in_features=4, out_features=1, bias=True)
    (tensor([1.]), tensor([[1., 2., 3., 4.]]), tensor([[ 7.],
            [ 0.],
            [27.],
            [ 0.]]))
    (tensor([[1.]]),)
    ReLU()
    (tensor([[1., 0., 3., 0.]]),)
    (tensor([[1., 2., 3., 4.]]),)
    Linear(in_features=3, out_features=4, bias=True)
    (tensor([1., 0., 3., 0.]), tensor([[22., 26., 30.]]), tensor([[1., 0., 3., 0.],
            [1., 0., 3., 0.],
            [1., 0., 3., 0.]]))
    (tensor([[1., 0., 3., 0.]]),)
    '''
    

    register_backward_hook只能操作简单模块,而不能操作包含多个子模块的复杂模块。 如果对复杂模块用了 backward hook,那么我们只能得到该模块最后一次简单操作的梯度信息。

    可以这么用,可以得到一个模块的梯度。

    class Mymodel(nn.Module):
    	......
        
    model = Mymodel()
    model.register_backward_hook(hook_fn_backward)
    
  • 相关阅读:
    Android基于mAppWidget实现手绘地图(五)--如何创建地图资源
    Android基于mAppWidget实现手绘地图(四)--如何附加javadoc
    Android基于mAppWidget实现手绘地图(三)--环境搭建
    Android基于mAppWidget实现手绘地图(二)--概要
    Android基于mAppWidget实现手绘地图(一)--简介
    网络通信之Socket与LocalSocket的比较
    Python-Django 整合Django和jquery-easyui
    Python-Django 第一个Django app
    RobotFramework 官方demo Quick Start Guide rst配置文件分析
    RobotFramework RobotFramework官方demo Quick Start Guide浅析
  • 原文地址:https://www.cnblogs.com/aoru45/p/11297066.html
Copyright © 2011-2022 走看看