zoukankan      html  css  js  c++  java
  • Saliency map实现

    import PIL, torch, torchvision
    import matplotlib.pyplot as plt
    import sys
    import pandas as pd
    
    # 标准化
    def normalize(image):
        """Min-max scale *image* so its values span the range [0, 1].

        Args:
            image: An array-like object exposing ``min()`` and ``max()`` and
                supporting element-wise arithmetic (called with torch tensors
                elsewhere in this snippet).

        Returns:
            ``(image - min) / (max - min)``. When every element is equal the
            range is zero; the divisor then falls back to 1 so the result is
            all zeros rather than NaN from a 0/0 division.
        """
        low = image.min()
        span = image.max() - low
        if span == 0:
            # Constant input: avoid 0/0 and return zeros instead.
            span = 1
        return (image - low) / span
    
    
    def show_saliency_map(img_path, model, size=100, cmap=plt.cm.hot):
        """Plot an image beside two gradient-based saliency maps for its top class.

        The image is resized to ``size`` x ``size``, passed through ``model``,
        and the score of the highest-scoring class is back-propagated to the
        input pixels. Three panels are shown: the resized image, an RGB map of
        the absolute input gradient (each channel min-max normalized), and a
        single-channel map built from the per-pixel maximum absolute gradient
        across channels.

        Side effects: reads ``./1000class_dict.csv`` from the current working
        directory, prints the predicted class names, and opens a matplotlib
        window via ``plt.show()``.

        Args:
            img_path: Path of the image file to open with PIL.
            model: A torch module called on a ``(1, 3, size, size)`` tensor;
                its output is indexed as ``output[0, class_index]``.
            size: Side length, in pixels, the image is resized to.
            cmap: Matplotlib colormap used for the single-channel map.

        Returns:
            None.
        """
        # Put the model in evaluation mode.
        model.eval()

        # Image transforms: resize + tensor for the model, resize only for
        # display, tensor -> PIL for showing the RGB saliency map.
        aug1 = torchvision.transforms.Compose(
            [torchvision.transforms.Resize((size, size)),
             torchvision.transforms.ToTensor()])
        aug2 = torchvision.transforms.Resize((size, size))
        aug3 = torchvision.transforms.ToPILImage()

        # Load the image and force three channels.
        img = PIL.Image.open(img_path)
        img = img.convert("RGB")

        # Add the batch dimension and track gradients w.r.t. the input pixels.
        timg = aug1(img).view(1, 3, size, size)
        timg.requires_grad = True

        # Forward pass, then pick the index of the highest-scoring class.
        output = model(timg)
        timg_class = output.argmax(dim=1).item()

        # Class-name lookup tables keyed by row index.
        # NOTE(review): assumes column 3 holds English names and column 2
        # Chinese names, as the variable names suggest -- confirm against the CSV.
        pd_data = pd.read_csv('./1000class_dict.csv')

        pd_data_en = pd_data.iloc[:, 3]
        class_index_en = pd_data_en.to_dict()

        pd_data_zh = pd_data.iloc[:, 2]
        class_index_zh = pd_data_zh.to_dict()

        print(class_index_zh[timg_class], class_index_en[timg_class])

        # Back-propagate the score of the predicted class to the input.
        s = output[0, timg_class]
        s.backward()

        with torch.no_grad():
            # Gradient of the class score w.r.t. the single input image.
            grad = timg.grad[0]
            # Per pixel: maximum absolute gradient over the colour channels
            # ([0] selects the max values, [1] would be the arg-max indices).
            graph = torch.max(torch.abs(grad), dim=0)[0]
            lambd = 0.1
            # Variant labelled "the paper's method" by the original author.
            # NOTE(review): the subtracted term is one scalar, so every pixel
            # is shifted by the same amount.
            saliency_map_gray = (graph - lambd * (torch.norm(timg, 2) ** 2).item()).numpy()

            # Absolute gradient, min-max normalized one channel at a time.
            # Iterating over grad (3, size, size) walks the channel axis;
            # iterating the batched (1, 3, size, size) tensor would normalize
            # all channels jointly.
            saliency_map_rgb = torch.stack(
                [normalize(channel) for channel in grad.abs().cpu()])

        fig, ax = plt.subplots(1, 3)
        raw_img = aug2(img)
        ax[0].imshow(raw_img)
        ax[0].set_title(class_index_en[timg_class])

        rgb_saliency = aug3(saliency_map_rgb)
        ax[1].imshow(rgb_saliency)
        ax[1].set_title('RGB map')
        ax[2].imshow(saliency_map_gray, cmap=cmap)
        ax[2].set_title('gray map')
        plt.show()
    
    # Demo: visualise saliency for a sample image with a pretrained ResNet-18
    # at a 224x224 input resolution.
    img = './panda.png'
    model = torchvision.models.resnet18(pretrained=True)
    show_saliency_map(img_path=img, model=model, size=224)

    参考:Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps, https://arxiv.org/abs/1312.6034

  • 相关阅读:
    Java_Habse_add
    Java_Habse_shell
    android Studio 出现:Unable to resolve dependency for ':app@debug/compileClasspath'
    微信小程序云函数中有以下未安装的依赖,如果未安装即全量上传
    Bittorrent Protocol Specification v1.0 中文
    BT客户端实现 Peer协议设计
    NGINX 配置 SSL 双向认证
    openssl、x509、crt、cer、key、csr、ssl、tls 这些都是什么鬼?
    ssl双向认证和单向认证原理
    网络服务器之HTTPS服务
  • 原文地址:https://www.cnblogs.com/mydrizzle/p/13977924.html
Copyright © 2011-2022 走看看