  • Visualizing Feature Maps During PyTorch Model Training

    When you want to understand what a network has learned, the usual instinct is to visualize its feature maps. For example, the first stage of the backbone outputs a feature map of shape [batch, channel, height, width]; how do you actually render it as images?

    If this has been bothering you too, this post gives a brief introduction to feature-map visualization and shows the code that produces it. The content is as follows:

    1. Feature map visualization. This is the simplest method: feed one image through the network, then take the output feature map of some intermediate layer and display each of its channels as an image.

    2. The feature-map visualization code is as follows:

    def featuremap_visual(feature,
                          out_dir=None,          # directory in which to save the feature-map image
                          save_feature=True,     # whether to save the feature map as an image file
                          show_feature=True,     # whether to display the feature map with plt
                          feature_title=None,    # figure title; defaults to the feature shape
                          num_ch=-1,             # number of leading channels to show; -1 or None shows all
                          nrow=8,                # number of feature-map channels per row
                          padding=10,            # spacing, in pixels, between feature maps
                          pad_value=1            # pixel value used for the spacing
                          ):
        import matplotlib.pylab as plt
        import torchvision
        import os
        # feature = feature.detach().cpu()
        b, c, h, w = feature.shape
        feature = feature[0]            # first sample in the batch: [c, h, w]
        feature = feature.unsqueeze(1)  # -> [c, 1, h, w], one grayscale image per channel

        if num_ch is not None and c > num_ch > 0:
            feature = feature[:num_ch]

        img = torchvision.utils.make_grid(feature, nrow=nrow, padding=padding, pad_value=pad_value)
        img = img.detach().cpu()
        img = img.numpy()
        images = img.transpose((1, 2, 0))

        # title = str(images.shape) if feature_title is None else str(feature_title)
        title = 'hwc-' + str(h) + '-' + str(w) + '-' + str(c) if feature_title is None else str(feature_title)

        plt.title(title)
        plt.imshow(images)
        if save_feature:
            # root = r'C:\Users\Administrator\Desktop\CODE_TJ123'
            # plt.savefig(os.path.join(root, '1.jpg'))
            out_root = title + '.jpg' if out_dir == '' or out_dir is None else os.path.join(out_dir, title + '.jpg')
            plt.savefig(out_root)

        if show_feature:
            plt.show()
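
    To sanity-check the function, you can call it on a random tensor shaped like a real layer output. The snippet below is only a minimal usage sketch: the shape [1, 64, 56, 56], the title 'layer1-demo', and the parameter choices are arbitrary examples rather than values from the original post.

    import torch

    feat = torch.rand(1, 64, 56, 56)  # stand-in for a real [batch, channel, height, width] feature map
    featuremap_visual(feat,
                      save_feature=False,           # just display, do not write a .jpg
                      show_feature=True,
                      feature_title='layer1-demo',  # hypothetical title
                      num_ch=16,                    # show only the first 16 channels
                      nrow=4)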

    3. Visualization integrated into a ResNet network (the key step is simply to call featuremap_visual inside forward). The complete code is below; a hook-based alternative that avoids editing forward is sketched after the results.

    The ResNet network structure itself is described in my blog post: https://www.cnblogs.com/tangjunjun/p/14947868.html

    """
    @author: tangjun
    @contact: 511026664@qq.com
    @time: 2020/12/7 22:48
    @desc: Rewritten residual backbone, used to build the feature-extraction module
    """
    
    import torch.nn as nn
    import torch
    from collections import OrderedDict
    
    
    def Conv(in_planes, out_planes, **kwargs):
        "3x3 convolution with padding"
        padding = kwargs.get('padding', 1)
        bias = kwargs.get('bias', False)
        stride = kwargs.get('stride', 1)
        kernel_size = kwargs.get('kernel_size', 3)
        out = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        return out
    
    
    class BasicBlock(nn.Module):
        expansion = 1
    
        def __init__(self, inplanes, planes, stride=1, downsample=None):
            super(BasicBlock, self).__init__()
            self.conv1 = Conv(inplanes, planes, stride=stride)
            self.bn1 = nn.BatchNorm2d(planes)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = Conv(planes, planes)
            self.bn2 = nn.BatchNorm2d(planes)
            self.downsample = downsample
            self.stride = stride
    
        def forward(self, x):
            residual = x
    
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
    
            out = self.conv2(out)
            out = self.bn2(out)
    
            if self.downsample is not None:
                residual = self.downsample(x)
    
            out += residual
            out = self.relu(out)
    
            return out
    
    
    class Bottleneck(nn.Module):
        expansion = 4
    
        def __init__(self, inplanes, planes, stride=1, downsample=None):
            super(Bottleneck, self).__init__()
            self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
            self.bn1 = nn.BatchNorm2d(planes)
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                   padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(planes)
            self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
            self.bn3 = nn.BatchNorm2d(planes * 4)
            self.relu = nn.ReLU(inplace=True)
            self.downsample = downsample
            self.stride = stride
    
        def forward(self, x):
            residual = x
    
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
    
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)
    
            out = self.conv3(out)
            out = self.bn3(out)
    
            if self.downsample is not None:
                residual = self.downsample(x)
    
            out += residual
            out = self.relu(out)
    
            return out
    
    
    class Resnet(nn.Module):
        arch_settings = {
            18: (BasicBlock, (2, 2, 2, 2)),
            34: (BasicBlock, (3, 4, 6, 3)),
            50: (Bottleneck, (3, 4, 6, 3)),
            101: (Bottleneck, (3, 4, 23, 3)),
            152: (Bottleneck, (3, 8, 36, 3))
        }
    
        def __init__(self,
                     depth=50,
                     in_channels=None,
                     pretrained=None,
                     frozen_stages=-1
                     # num_classes=None
                     ):
            super(Resnet, self).__init__()
            self.inplanes = 64
            self.inchannels = in_channels if in_channels is not None else 3  # number of input channels
            # self.num_classes=num_classes
            self.block, layers = self.arch_settings[depth]
            self.frozen_stages = frozen_stages
            self.conv1 = nn.Conv2d(self.inchannels, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.bn1 = nn.BatchNorm2d(64)
            self.relu = nn.ReLU(inplace=True)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    
            self.layer1 = self._make_layer(self.block, 64, layers[0], stride=1)
            self.layer2 = self._make_layer(self.block, 128, layers[1], stride=2)
            self.layer3 = self._make_layer(self.block, 256, layers[2], stride=2)
            self.layer4 = self._make_layer(self.block, 512, layers[3], stride=2)
    
            # self.avgpool = nn.AvgPool2d(7)
            # self.fc = nn.Linear(512 * self.block.expansion, self.num_classes)
            self._freeze_stages()  # freeze the stem and the first frozen_stages stages, if requested
            if pretrained is not None:
                self.init_weights(pretrained=pretrained)
    
        def _freeze_stages(self):
            if self.frozen_stages >= 0:
                self.bn1.eval()
                for m in [self.conv1, self.bn1]:
                    for param in m.parameters():
                        param.requires_grad = False
            for i in range(1, self.frozen_stages + 1):
                m = getattr(self, 'layer{}'.format(i))
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False
    
        def init_weights(self, pretrained=None):
            if isinstance(pretrained, str):
                self.load_checkpoint(pretrained)
            elif pretrained is None:
                for m in self.modules():
                    if isinstance(m, nn.Conv2d):
                        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out', nonlinearity='relu')
                        if hasattr(m, 'bias') and m.bias is not None:  # the module defines a bias and it is not None
                            nn.init.constant_(m.bias, 0)
    
                    elif isinstance(m, nn.BatchNorm2d):
                        m.weight.data.fill_(1)
                        m.bias.data.zero_()
    
        def load_checkpoint(self, pretrained):

            checkpoint = torch.load(pretrained)
            if isinstance(checkpoint, OrderedDict):
                state_dict = checkpoint
            elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                raise RuntimeError('No state_dict found in checkpoint file {}'.format(pretrained))

            # strip the 'module.' prefix left by nn.DataParallel, if present
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {k[7:]: v for k, v in state_dict.items()}

            unexpected_keys = []  # checkpoint keys that do not exist in the model
            model_state = self.state_dict()  # the model's own parameters and buffers

            for name, param in state_dict.items():  # iterate over the pretrained weights
                if name not in model_state:
                    unexpected_keys.append(name)
                    continue
                if isinstance(param, torch.nn.Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data

                try:
                    model_state[name].copy_(param)  # copy the pretrained weight into the model
                except Exception:
                    raise RuntimeError(
                        'While copying the parameter named {}, '
                        'its dimensions in the model are {} but '
                        'its dimensions in the checkpoint are {}.'.format(
                            name, model_state[name].size(), param.size()))
            missing_keys = set(model_state.keys()) - set(state_dict.keys())
            print('missing_keys:', missing_keys)
    
        def _make_layer(self, block, planes, num_blocks, stride=1):
            downsample = None
            if stride != 1 or self.inplanes != planes * block.expansion:
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(planes * block.expansion),
                )
    
            layers = []
            layers.append(block(self.inplanes, planes, stride, downsample))
            self.inplanes = planes * block.expansion
            for i in range(1, num_blocks):
                layers.append(block(self.inplanes, planes))
    
            return nn.Sequential(*layers)
    
        def forward(self, x):
            outs = []
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)
    
            x = self.layer1(x)
            outs.append(x)
            featuremap_visual(x)
    
            x = self.layer2(x)
            outs.append(x)
            featuremap_visual(x)
    
            x = self.layer3(x)
            outs.append(x)
            featuremap_visual(x)
    
            x = self.layer4(x)
            outs.append(x)
    
            # x = self.avgpool(x)
            # x = x.view(x.size(0), -1)
            # x = self.fc(x)
    
            return tuple(outs)
    
    
    def featuremap_visual(feature,
                          out_dir=None,          # directory in which to save the feature-map image
                          save_feature=True,     # whether to save the feature map as an image file
                          show_feature=True,     # whether to display the feature map with plt
                          feature_title=None,    # figure title; defaults to the feature shape
                          num_ch=-1,             # number of leading channels to show; -1 or None shows all
                          nrow=8,                # number of feature-map channels per row
                          padding=10,            # spacing, in pixels, between feature maps
                          pad_value=1            # pixel value used for the spacing
                          ):
        import matplotlib.pylab as plt
        import torchvision
        import os
        # feature = feature.detach().cpu()
        b, c, h, w = feature.shape
        feature = feature[0]            # first sample in the batch: [c, h, w]
        feature = feature.unsqueeze(1)  # -> [c, 1, h, w], one grayscale image per channel

        if num_ch is not None and c > num_ch > 0:
            feature = feature[:num_ch]

        img = torchvision.utils.make_grid(feature, nrow=nrow, padding=padding, pad_value=pad_value)
        img = img.detach().cpu()
        img = img.numpy()
        images = img.transpose((1, 2, 0))

        # title = str(images.shape) if feature_title is None else str(feature_title)
        title = 'hwc-' + str(h) + '-' + str(w) + '-' + str(c) if feature_title is None else str(feature_title)

        plt.title(title)
        plt.imshow(images)
        if save_feature:
            # root = r'C:\Users\Administrator\Desktop\CODE_TJ123'
            # plt.savefig(os.path.join(root, '1.jpg'))
            out_root = title + '.jpg' if out_dir == '' or out_dir is None else os.path.join(out_dir, title + '.jpg')
            plt.savefig(out_root)

        if show_feature:
            plt.show()
    
    
    import cv2
    import numpy as np
    
    
    def imnormalize(img,
                    mean=[123.675, 116.28, 103.53],  # ImageNet mean on the 0-255 scale (0.485, 0.456, 0.406 * 255)
                    std=[58.395, 57.12, 57.375],     # ImageNet std on the 0-255 scale (0.229, 0.224, 0.225 * 255)
                    to_rgb=True
                    ):
        if to_rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB
        img = img.astype(np.float32)
        return (img - mean) / std
    
    
    if __name__ == '__main__':
        import matplotlib.pylab as plt

        img = cv2.imread('1.jpg')  # read the test image (BGR)

        img = imnormalize(img)
        img = torch.from_numpy(img)

        img = torch.unsqueeze(img, 0)   # [h, w, c] -> [1, h, w, c]
        img = img.permute(0, 3, 1, 2)   # -> [1, c, h, w]
        img = img.to(dtype=torch.float32)
        img = img.to('cuda:0')

        model = Resnet(depth=50)
        model.init_weights(pretrained='./resnet50.pth')  # optional: comment out if no pretrained weights are available
        model = model.cuda()
        out = model(img)

    Running the script produces one feature-map grid image per visualized layer (layer1 through layer3).
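
    As an alternative to calling featuremap_visual inside forward, the same intermediate outputs can be captured with PyTorch forward hooks, leaving the network code untouched. The sketch below is an assumption-laden illustration rather than part of the original post: it uses torchvision's resnet18 purely as a stand-in model and reuses the featuremap_visual function defined above.

    import torch
    import torchvision

    def visualize_hook(module, inputs, output):
        # called automatically right after the hooked module finishes its forward pass
        featuremap_visual(output, save_feature=False, num_ch=16, nrow=4)

    hook_model = torchvision.models.resnet18()                        # stand-in model (any nn.Module works)
    handle = hook_model.layer1.register_forward_hook(visualize_hook)  # watch layer1's output

    with torch.no_grad():
        hook_model(torch.rand(1, 3, 224, 224))                        # dummy input; running forward triggers the hook

    handle.remove()                                                   # detach the hook when done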

    Reference: https://zhuanlan.zhihu.com/p/189898699
