  • [NN] Guided Backpropagation Visualization

    PyTorch Guided Backpropagation

    Intro

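    Guided backpropagation changes how gradients flow back through ReLU. Negative gradients are not propagated and only positive ones pass through, on top of the usual ReLU rule that blocks positions where the forward activation was not positive. The gradient that reaches the first conv layer therefore keeps only the contributions that excite the downstream ReLU activations, and visualizing it shows the input regions the network responds to. (Strictly speaking, what is visualized is the gradient with respect to the input image, not an activation map.)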
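    The per-element rule fits in a few lines. Below is a minimal NumPy sketch comparing the ordinary ReLU backward with the guided version for a single layer. The function names are hypothetical and not from the post; `x` stands for the ReLU's forward input and `grad_out` for the gradient arriving from the layer above.

```python
import numpy as np


def relu_backward(x, grad_out):
    """Ordinary ReLU backward: pass the gradient wherever the forward input was positive."""
    return grad_out * (x > 0)


def guided_relu_backward(x, grad_out):
    """Guided backprop: additionally drop negative incoming gradients."""
    return np.clip(grad_out, 0.0, None) * (x > 0)


if __name__ == "__main__":
    x = np.array([-1.0, 2.0, 3.0, 0.5])         # ReLU inputs from the forward pass
    grad_out = np.array([1.0, -1.0, 2.0, 0.3])   # gradient from the layer above
    print(relu_backward(x, grad_out).tolist())         # [0.0, -1.0, 2.0, 0.3]
    print(guided_relu_backward(x, grad_out).tolist())  # [0.0, 0.0, 2.0, 0.3]
```

    Both versions zero index 0, where the forward input was negative. Only the guided version also zeroes index 1, where the incoming gradient is negative.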

    A quick note on a network visualization method built on hooks.

    code

    import torch
    import torch.nn as nn
    from torchvision import transforms
    from models.densenet import densenet121  # the author's own DenseNet returning [yaw, pitch, roll] logits
    from PIL import Image
    import numpy as np
    import matplotlib.pyplot as plt
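    class Guided_Prop():
        def __init__(self, model):
            self.model = model
            self.model.eval()
            self.out_img = None
            self.activation_maps = {}  # ReLU outputs from the forward pass, keyed by module
            self.handles = []          # hook handles, removed after each visualization

        def register_hooks(self):
            def first_layer_hook(module, grad_in, grad_out):
                # grad_in[0] is the gradient w.r.t. the input image, shape (b, c, h, w)
                self.out_img = grad_in[0]

            def forward_hook_fn(module, input_feature, output_feature):
                # remember where this ReLU was active
                self.activation_maps[module] = output_feature.detach()

            def backward_hook_fn(module, grad_in, grad_out):
                # guided backprop: keep the gradient only where the forward activation
                # was positive and the incoming gradient is positive
                mask = (self.activation_maps[module] > 0).float()
                g_positive = torch.clamp(grad_out[0], min=0.)
                return (mask * g_positive,)

            # modules() also visits ReLUs nested inside dense blocks,
            # which features.named_children() would miss
            for module in self.model.modules():
                if isinstance(module, nn.ReLU):
                    module.inplace = False  # in-place ReLU conflicts with full backward hooks
                    self.handles.append(module.register_forward_hook(forward_hook_fn))
                    self.handles.append(module.register_full_backward_hook(backward_hook_fn))
            first_layer = next(self.model.features.children())
            self.handles.append(first_layer.register_full_backward_hook(first_layer_hook))

        def remove_hooks(self):
            for handle in self.handles:
                handle.remove()
            self.handles = []
            self.activation_maps = {}

        def visualize(self, input_image):
            softmax = nn.Softmax(dim=1)
            idx_tensor = torch.arange(61, dtype=torch.float32)  # bin indices 0..60
            self.register_hooks()
            self.model.zero_grad()
            out = self.model(input_image)  # [yaw, pitch, roll] logits, each (b, 61)
            # expected bin index under the softmax, mapped to degrees
            yaw = torch.sum(softmax(out[0]) * idx_tensor, dim=1) * 3 - 90.
            pitch = torch.sum(softmax(out[1]) * idx_tensor, dim=1) * 3 - 90.
            roll = torch.sum(softmax(out[2]) * idx_tensor, dim=1) * 3 - 90.
            # multi-task output: backpropagate the sum of the three angles
            (yaw + pitch + roll).sum().backward()
            self.remove_hooks()
            result = self.out_img[0].detach().permute(1, 2, 0)  # chw -> hwc for plt.imshow
            return result.cpu().numpy()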
    def normalize(I):
        # standardize the gradient map, shrink it around 0.5 and clip to [0, 1] for display
        norm = (I - I.mean()) / I.std()
        norm = norm * 0.1
        norm = norm + 0.5
        norm = norm.clip(0, 1)
        return norm
    if __name__ == "__main__":
        input_size = 224
        model = densenet121(pretrained=False, num_classes=61)
        model.load_state_dict(torch.load("./ckpt/DenseNet/model_2692_.pkl", map_location="cpu"))

        img = Image.open("/media/xueaoru/其他/ML/head_pose_work/brick/head_and_heads/test/BIWI00009409_-17_+1_+17.png").convert("RGB")
        transform = transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean/std
        ])
        # the input must require grad so that a gradient w.r.t. the image is computed
        tensor = transform(img).unsqueeze(0).requires_grad_()

        viz = Guided_Prop(model)
        result = viz.visualize(tensor)
        result = normalize(result)

        plt.imshow(result)
        plt.show()
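
    The yaw/pitch/roll lines in `visualize` turn each 61-way head into a continuous angle. They take the softmax expectation of the bin index and map it to degrees with `* 3 - 90`. The NumPy sketch below isolates that decoding step. It assumes, as the constants suggest but the post does not state, 3-degree bins with index 0 mapping to -90 and index 60 to +90. `expected_angle` is a hypothetical helper name.

```python
import numpy as np


def softmax(logits, axis=-1):
    z = logits - logits.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def expected_angle(logits, num_bins=61, bin_width=3.0, offset=90.0):
    """Soft-argmax decoding per row: E[bin index] * bin_width - offset."""
    idx = np.arange(num_bins, dtype=np.float64)
    return (softmax(logits) * idx).sum(axis=-1) * bin_width - offset


if __name__ == "__main__":
    uniform = np.zeros((1, 61))          # no preference: expected index is 30
    peaked = np.full((1, 61), -50.0)
    peaked[0, 35] = 50.0                 # almost all probability mass on bin 35
    print(expected_angle(uniform))       # close to 0 degrees
    print(expected_angle(peaked))        # close to 15 degrees
```

    Because the decoded angle is a smooth function of the logits, it can be backpropagated directly. That is why the script can call `backward()` on the summed angles without choosing a class.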
    
    

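    Because this is a multi-task problem (yaw, pitch and roll are predicted together), the script simply backpropagates from the predicted angles. For an ordinary classification problem, you can instead specify a target class and backpropagate a one-hot vector built from the ground-truth label.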
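    For the classification case, the one-hot trick can be sketched as follows. This assumes a model that returns logits of shape (b, num_classes). `class_input_gradient` is a hypothetical helper, and a toy `nn.Linear` stands in for a real network so the result can be checked by hand. In the guided-backprop script, this one-hot `backward` call would take the place of backpropagating the summed angles, with the hooks registered as before.

```python
import torch
import torch.nn as nn


def class_input_gradient(model, x, target):
    """Gradient of each sample's target-class logit w.r.t. the input.

    Backpropagating a one-hot vector through the logits is equivalent to
    calling backward() on the selected class score alone.
    """
    x = x.detach().clone().requires_grad_()
    logits = model(x)                                    # (b, num_classes)
    one_hot = torch.zeros_like(logits)
    one_hot[torch.arange(logits.shape[0]), target] = 1.0
    model.zero_grad()
    logits.backward(gradient=one_hot)
    return x.grad


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Linear(4, 3)              # toy "classifier"
    x = torch.randn(2, 4)
    target = torch.tensor([2, 0])
    grad = class_input_gradient(model, x, target)
    # for a linear layer, d(logit_c)/dx is row c of the weight matrix
    print(torch.allclose(grad, model.weight[target]))   # True
```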

    Gradient visualization for head pose estimation.

  • Original post: https://www.cnblogs.com/aoru45/p/11347226.html