zoukankan      html  css  js  c++  java
  • 卷积网络可解释性复现 | Grad-CAM | ICCV | 2017

    觉得本文不错的可以点个赞。有问题联系作者微信cyx645016617,之后主要转战公众号,不在博客园和CSDN更新。

    论文名称:“Grad-CAM:
    Visual Explanations from Deep Networks via Gradient-based Localization”
    论文地址:https://openaccess.thecvf.com/content_ICCV_2017/papers/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.pdf
    论文期刊:ICCV International Conference on Computer Vision

    1 综述

    总的来说,卷积网络的可解释性一直是一个很重要的问题,你要用人工智能去说服别人,但是这是一个“黑箱”问题,卷积网络运行机理是一个black box,说不清楚内在逻辑。

    因此很多学者提出了各种各样的可视化来解释的方法。我个人尝试过的、可以从一定角度进行解释的可视化方法有:t-sne降维,attention可视化,可变卷积的可视化等,但是其实这些的可视化方法,并不能直接的对模型进行解释,只是能说明模型分类是准确的

    CAM的全称是Class Activation Mapping,对于分类问题,我们可以直观的通过这种方法,来进行解释方向的可视化。

    grad-CAM是CAM的进阶版本,更加方便实施、即插即用。

    2 CAM

    CAM的原理是实现可解释性的根本,所以我通俗易懂的讲一讲。

    上面是一个传统CNN的结构,通过卷积和池化层后,把特征图拉平成一维,然后是全连接层进行分类。

    那么CAM的网络是什么样子呢?基本和上面的结构相同

    图中有一个GAP池化层,全局平均池化层。这个就是求取每一个通道的均值,可以理解为核是和特征图一样大的一般的平均池化层,假如输出特征图是一个8通道的,224x224的特征图,那么经过GAP这个池化层,就会得到8个数字,一个通道贡献一个数字,这个数字是一个通道的代表

    然后经过GAP之后的一维向量,再跟上一个全连接层,得到类别的概率。

    上图中左边就是经过GAP得到的向量,其数量就是最后一层特征图的通道数,右边的向量的数量就是类别的数量。

    关键来了,CAM的可解释性的逻辑在于:假设我们最终预测的类别是羊驼,也就是说,模型给羊驼的打分最高。我们可以得到,左边向量计算出羊驼的权重值,也就是全连接层中的一部分权重值。这个权重值就是!!!就是最后一层特征图每一个通道的权重值。之前也提到了GAP的输出的一个向量代表着GAP输入的特征图的每一个通道嘛

    这样我们通过最后一个全连接层获取到最后一个特征图的每一个通道对于某一个类别的贡献的权重值。我们对最后一个特征图的每一个通道的加权平均,就是我们得到的CAM对卷积的解释。之后可以上采样到整个图片那么大小,就像是论文给出的样子:

    大家应该明白这个原理了,但是这样要修改模型的结构。之前训练的模型用不了了,这很麻烦,所以才有了Grad-CAM的提出。

    3 Grad-CAM

    Grad-CAM思路和CAM也是相同的,也是需要得到特征图每一个通道的权重值,然后做一个加权和。

    所以关键在于,如何计算这个权重值,论文提出了这样的计算方法:

    其中,z是一个特征图的像素量,就是width*height,可以看到,前面就是CAM的GAP的一个过程,后面的(y^c)是模型给类别c的打分,(A_{ij}^k)就是特征图中ij这个位置的元素值。那么对这个求导,其实就是这个位置的梯度。

    所以用pytorch的实现如下:

    self.model.features.zero_grad()
    self.model.classifier.zero_grad()
    one_hot.backward(retain_graph=True)#仅包含有最大概率值,然后进行反向传播  
    grads_val = self.extractor.get_gradients()[-1].cpu().data.numpy()
    weights = np.mean(grads_val, axis=(2, 3))[0, :]#求平均,就是上面这段公式
    # 简单的说上面的逻辑就是先反向传播之后,然后获取对应位置的梯度,然后计算平均。
    

    在论文中作者证明了Grad-CAM和CAM的等价的结论,想了解的可以看看。

    4 pytorch完整代码

    官方提供了github代码:https://github.com/jacobgil/pytorch-grad-cam

    其中关键的地方是:

    class FeatureExtractor():
        """ Class for extracting activations and
        registering gradients from targetted intermediate layers """
    
        def __init__(self, model, target_layers):
            self.model = model
            self.target_layers = target_layers
            self.gradients = []
    
        def save_gradient(self, grad):
            self.gradients.append(grad)
    
        def __call__(self, x):
            outputs = []
            self.gradients = []
            for name, module in self.model._modules.items():
                x = module(x)
                if name in self.target_layers:
                    x.register_hook(self.save_gradient)
                    outputs += [x]
            return outputs, x
    
    class ModelOutputs():
        """ Class for making a forward pass, and getting:
        1. The network output.
        2. Activations from intermeddiate targetted layers.
        3. Gradients from intermeddiate targetted layers. """
    
        def __init__(self, model, feature_module, target_layers):
            self.model = model
            self.feature_module = feature_module
            self.feature_extractor = FeatureExtractor(self.feature_module, target_layers)
    
        def get_gradients(self):
            return self.feature_extractor.gradients
    
        def __call__(self, x):
            target_activations = []
            for name, module in self.model._modules.items():
                if module == self.feature_module:
                    target_activations, x = self.feature_extractor(x)
                elif "avgpool" in name.lower():
                    x = module(x)
                    x = x.view(x.size(0),-1)
                else:
                    if name is 'classifier':
                        x = x.view(x.size(0), -1)
                    x = module(x)
    
            return target_activations, x
    
    class GradCam:
        def __init__(self, model, feature_module, target_layer_names, use_cuda):
            self.model = model
            self.feature_module = feature_module
            self.model.eval()
            self.cuda = use_cuda
            if self.cuda:
                self.model = model.cuda()
    
            self.extractor = ModelOutputs(self.model, self.feature_module, target_layer_names)
    
        def forward(self, input_img):
            return self.model(input_img)
    
        def __call__(self, input_img, target_category=None):
            if self.cuda:
                input_img = input_img.cuda()
    
            features, output = self.extractor(input_img)
    
            if target_category == None:
                target_category = np.argmax(output.cpu().data.numpy())
    
            one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
            one_hot[0][target_category] = 1
            one_hot = torch.from_numpy(one_hot).requires_grad_(True)
            if self.cuda:
                one_hot = one_hot.cuda()
            
            one_hot = torch.sum(one_hot * output)
    
            self.feature_module.zero_grad()
            self.model.zero_grad()
            one_hot.backward(retain_graph=True)
    
            grads_val = self.extractor.get_gradients()[-1].cpu().data.numpy()
    
            target = features[-1]
            target = target.cpu().data.numpy()[0, :]
    
            weights = np.mean(grads_val, axis=(2, 3))[0, :]
            cam = np.zeros(target.shape[1:], dtype=np.float32)
    
            for i, w in enumerate(weights):
                cam += w * target[i, :, :]
    
            cam = np.maximum(cam, 0)
            cam = cv2.resize(cam, input_img.shape[2:])
            cam = cam - np.min(cam)
            cam = cam / np.max(cam)
            return cam
    

    把这一段复制到自己的代码中后,可以参考下面的代码逻辑,简单改写自己的代码即可实现可视化(看不懂的话还是看github):

    grad_cam = GradCam(model = model,feature_module = model.features,target_layer_names=['11'],use_cuda=True)
    def draw(ax,grayscale_cam,data):
        heatmap = cv2.applyColorMap(np.uint8(255 * grayscale_cam), cv2.COLORMAP_JET)
        heatmap = heatmap + data.detach().cpu().numpy()[0,0].reshape(28,28,1).repeat(3,axis=2)
        heatmap = heatmap / np.max(heatmap)
        ax.imshow(heatmap)
    for data,target in val_loader:
        if torch.cuda.is_available():
            data = data.cuda()
            target = target.cuda()
        # 绘制9张可视化图
        fig = plt.figure(figsize=(12,12))
        for i in range(9):
            d = data[i:i+1]
            grayscale_cam = grad_cam(d)
            ax = fig.add_subplot(3,3,i+1)
            draw(ax,grayscale_cam,d)
        break
    

    输出图像为:

    有问题欢迎联系作者讨论,请多指教。

  • 相关阅读:
    文件格式——gff格式
    文件格式——fastq格式
    Java 8 新特性:1-函数式接口
    10分钟学会JAVA注解(annotation)
    spring MVC 乱码问题
    Tomcat 连接池详解
    DBCP连接池配置参数说明
    spring 事务无效解决方法
    spring mvc 存取值
    使用Criteria 实现两表的左外连接,返回根对象
  • 原文地址:https://www.cnblogs.com/PythonLearner/p/14206894.html
Copyright © 2011-2022 走看看