zoukankan      html  css  js  c++  java
  • Saliency map实现

    import PIL, torch, torchvision
    import matplotlib.pyplot as plt
    import sys
    import pandas as pd
    
    # 标准化
    def normalize(image):
         return (image - image.min()) / (image.max() - image.min())
    
    
    def show_saliency_map(img_path, model, size=100, cmap=plt.cm.hot):
    #     evaluate模式
         model.eval()
         
    #     图像变换
         aug1 = torchvision.transforms.Compose(
             [torchvision.transforms.Resize((size, size)),
              torchvision.transforms.ToTensor()])
         aug2 = torchvision.transforms.Resize((size, size))
         aug3 = torchvision.transforms.ToPILImage()
    
    #     读取一张图片
         img = PIL.Image.open(img_path)
         img = img.convert("RGB")
    #     变换
         timg = aug1(img).view(1, 3, size, size)
    #     梯度
         timg.requires_grad = True
    
    #     正向传播得到output
         output = model(timg)
    #     获取预测概率最大的index
         timg_class = output.argmax(dim=1).item()
    
    #     1000类dict
         pd_data = pd.read_csv('./1000class_dict.csv')
         
         pd_data_en = pd_data.iloc[:, 3]
         class_index_en = pd_data_en.to_dict()
         
         pd_data_zh = pd_data.iloc[:, 2]
         class_index_zh = pd_data_zh.to_dict()
         
         print(class_index_zh[timg_class],class_index_en[timg_class])
    
    #     找到output的对应fc输出单元
         s = output[0, timg_class]
    #     反向传播求此单元梯度
         s.backward()
    
        with torch.no_grad():
    #         得到了梯度
             grad = timg.grad.data[0]
    #         对梯度图处理,取绝对值,求像素通道最大值
             graph = torch.max(torch.abs(grad), dim=0)[0]  # [0]是max_value  [1]是max_index
             lambd = 0.1
    #         paper中的方法
             saliency_map_gray = (graph - lambd * (torch.norm(timg, 2) ** 2).item()).numpy()
             
    #         直接梯度求绝对值
             saliency_map_rgb = timg.grad.abs().cpu()
    #         将每个通道归一化
             saliency_map_rgb = torch.stack([normalize(item) for item in saliency_map_rgb])
    
        fig, ax = plt.subplots(1, 3)
         raw_img = aug2(img)
         ax[0].imshow(raw_img)
         ax[0].set_title(class_index_en[timg_class])
         
         rgb_saliency = aug3(saliency_map_rgb.view(3, size, size))
         ax[1].imshow(rgb_saliency)
         ax[1].set_title('RGB map')
         ax[2].imshow(saliency_map_gray, cmap=cmap)
         ax[2].set_title('gray map')
         plt.show()
    
    img = './panda.png'
    model = torchvision.models.resnet18(pretrained=True)
    show_saliency_map(img, model, size = 224)

    参考:Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps , https://arxiv.org/abs/1312.6034

  • 相关阅读:
    python web 框架的基本逻辑练习题
    jQuery 自定义方法(扩展方法)
    jQuery 的动画效果图片----隐藏打开方法
    jQuery 小练习-拖拉画面
    用jQuery来绑定事件的3种方法和区别
    css用hover制作下拉菜单
    巧用hover改变css样式和背景
    mpvue中按需引入echarts
    webpack配置css浏览器前缀
    Vue中使用Sass全局变量
  • 原文地址:https://www.cnblogs.com/mydrizzle/p/13977924.html
Copyright © 2011-2022 走看看