  • CBAM attention model for classification (PyTorch)

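    For a recent project I needed to combine a detection model with image classification, so I tidied the CBAM code into just two scripts, train.py and test.py, which train the classifier and classify images respectively. This post has two parts: the first briefly introduces the principle, drawing on another author's blog; the second explains how to use the rewritten code to train your own model and test images. Paper: CBAM: Convolutional Block Attention Module

    Code: https://github.com/tangjunjun966/CBAM_PyTorch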

    1. Basic principles

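    The Convolutional Block Attention Module (CBAM) is an attention module for convolutional blocks that combines spatial attention and channel attention. Compared with SENet, which attends only to channels, it can achieve better results.

    CBAM on a traditional VGG-style structure: the module is added after every convolutional layer.

    CBAM on a shortcut (residual) structure such as ResNet50: the module is added to every ResNet block. In the code in part 2 it is applied to the block's output just before the shortcut is added.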

     

    Channel attention module:

     

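    The input feature map is passed through global max pooling and global average pooling over its width and height, and each pooled result is fed through an MLP (the code in part 2 reuses the same fc1 and fc2 layers for both). The two MLP outputs are summed element-wise and passed through a sigmoid to produce the final channel attention map. This channel attention map is multiplied element-wise with the input feature map to produce the input features for the spatial attention module.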

     

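    In the paper's notation, the channel attention is

    M_c(F) = \sigma(MLP(AvgPool(F)) + MLP(MaxPool(F))) = \sigma(W_1(W_0(F^c_{avg})) + W_1(W_0(F^c_{max})))

    Here \sigma is the sigmoid function and r is the reduction ratio (W_0 \in R^{C/r \times C}, W_1 \in R^{C \times C/r}); W_0 is followed by a ReLU activation.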
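
    To make the channel attention steps concrete, here is a minimal pure-Python sketch on nested lists (no PyTorch). The helper names, the toy input, and the weight values are all hypothetical and chosen only for illustration; the real module in part 2 uses 1x1 convolutions as the shared MLP.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def shared_mlp(vec, w0, w1):
    """Two-layer MLP without biases; ReLU follows the first layer (W0)."""
    hidden = [max(0.0, sum(w * v for w, v in zip(row, vec))) for row in w0]
    return [sum(w * h for w, h in zip(row, hidden)) for row in w1]

def channel_attention(fmap, w0, w1):
    """fmap: C channels, each an H x W list of lists. Returns C weights in (0, 1)."""
    flat = [[v for row in ch for v in row] for ch in fmap]
    avg_desc = [sum(ch) / len(ch) for ch in flat]  # global average pooling per channel
    max_desc = [max(ch) for ch in flat]            # global max pooling per channel
    avg_out = shared_mlp(avg_desc, w0, w1)         # the same MLP weights are used
    max_out = shared_mlp(max_desc, w0, w1)         # for both pooled descriptors
    return [sigmoid(a + m) for a, m in zip(avg_out, max_out)]

# Toy input (assumption): C = 4 channels of size 2 x 2, channel c filled with c + 1.
fmap = [[[float(c + 1)] * 2 for _ in range(2)] for c in range(4)]
# Hypothetical weights with reduction ratio r = 2, so the hidden size is C / r = 2.
w0 = [[0.1, 0.2, 0.3, 0.4], [-0.4, -0.3, -0.2, -0.1]]      # shape (C/r, C)
w1 = [[0.5, -0.5], [0.25, 0.25], [-0.5, 0.5], [1.0, 0.0]]  # shape (C, C/r)

weights = channel_attention(fmap, w0, w1)
# Element-wise multiplication: every value in channel c is scaled by weights[c].
refined = [[[v * wgt for v in row] for row in ch] for ch, wgt in zip(fmap, weights)]
print(weights)
```

    The refined feature map has the same shape as the input; only the per-channel scale changes.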

    Spatial attention module:

     

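    This module takes the feature map output by the channel attention module as its input. It first applies max pooling and average pooling along the channel axis, then concatenates the two results along the channel axis. A convolution reduces the result to a single channel, and a sigmoid produces the spatial attention map. Finally, this map is multiplied with the module's input features to give the final output features.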

     

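    In the paper's notation, the spatial attention is

    M_s(F) = \sigma(f^{7 \times 7}([AvgPool(F); MaxPool(F)])) = \sigma(f^{7 \times 7}([F^s_{avg}; F^s_{max}]))

    Here \sigma is the sigmoid function and 7 \times 7 is the convolution kernel size; a 7x7 kernel was found to work better than a 3x3 kernel.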
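
    The spatial attention steps can be sketched the same way in pure Python. This is only an illustration under stated assumptions: the helper names, the toy input, and the kernel values are hypothetical, and a 3x3 kernel is used instead of 7x7 to keep the example small (the code in part 2 accepts either size).

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def conv2d_same(planes, kernels):
    """Sum of per-plane 2-D correlations with zero padding; output keeps H x W."""
    h, w = len(planes[0]), len(planes[0][0])
    k = len(kernels[0])
    pad = k // 2
    out = [[0.0] * w for _ in range(h)]
    for plane, kernel in zip(planes, kernels):
        for i in range(h):
            for j in range(w):
                for di in range(k):
                    for dj in range(k):
                        y, x = i + di - pad, j + dj - pad
                        if 0 <= y < h and 0 <= x < w:
                            out[i][j] += kernel[di][dj] * plane[y][x]
    return out

def spatial_attention(fmap, kernels):
    """fmap: C x H x W nested lists. Returns an H x W attention map in (0, 1)."""
    c, h, w = len(fmap), len(fmap[0]), len(fmap[0][0])
    # Pool along the channel axis: one average map and one max map.
    avg_map = [[sum(fmap[ch][i][j] for ch in range(c)) / c for j in range(w)] for i in range(h)]
    max_map = [[max(fmap[ch][i][j] for ch in range(c)) for j in range(w)] for i in range(h)]
    # "Concatenate" the two maps as a 2-channel input and convolve down to 1 channel.
    conv = conv2d_same([avg_map, max_map], kernels)
    return [[sigmoid(v) for v in row] for row in conv]

# Toy input (assumption): 2 channels of size 3 x 3.
fmap = [
    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
    [[0.0, 2.0, 0.0], [2.0, 4.0, 2.0], [0.0, 2.0, 0.0]],
]
# Hypothetical kernels: pass the average map through unchanged, ignore the max map.
identity = [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]
zeros = [[0.0] * 3 for _ in range(3)]

attn = spatial_attention(fmap, [identity, zeros])
# Every channel is scaled by the same per-pixel attention value.
refined = [[[v * a for v, a in zip(row, arow)] for row, arow in zip(ch, attn)] for ch in fmap]
print(attn)
```

    Padding of kernel_size // 2 keeps the attention map the same height and width as the input, which is why the code in part 2 uses padding 3 for a 7x7 kernel and 1 for a 3x3 kernel.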

    2. Using the code

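    Copy the code into a folder. As the --data_root comment in the script notes, ./datasets/ contains train/ and val/ folders, each holding one sub-folder of images per class.

    The training code is organized into separate pieces (dataset construction, model construction, and so on); copy it, adjust the parameters in args, and run it directly.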
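
    Because the training script reads images with torchvision's ImageFolder, class labels come from the sub-folder names in sorted order. The sketch below mimics that mapping with the standard library only. The folder names FALSE and NG are an assumption borrowed from the class names in the test script, and class_to_index is a hypothetical helper, not part of the author's code.

```python
import os
import tempfile

def class_to_index(split_dir):
    """Map class-folder names to integer labels in sorted order, as ImageFolder does."""
    classes = sorted(d for d in os.listdir(split_dir)
                     if os.path.isdir(os.path.join(split_dir, d)))
    return {name: idx for idx, name in enumerate(classes)}

# Build a throwaway datasets/ layout: <root>/{train,val}/{FALSE,NG}/
with tempfile.TemporaryDirectory() as root:
    for split in ('train', 'val'):
        for cls in ('NG', 'FALSE'):
            os.makedirs(os.path.join(root, split, cls))
    mapping = class_to_index(os.path.join(root, 'train'))

print(mapping)  # {'FALSE': 0, 'NG': 1}
```

    With this layout, index 1 is the NG class, which is the class given the larger weight (5) in the script's weighted cross-entropy loss.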

    The training code is as follows:

    
    
    from collections import OrderedDict
    import argparse
    import torch.optim as optim
    from torch.optim import lr_scheduler
    from torchvision import transforms, models, datasets
    from torchnet.meter import ClassErrorMeter, ConfusionMeter
    import torch.backends.cudnn as cudnn
    import torch.nn.functional as F
    import traceback
    import os
    import time
    import torch
    import torch.nn as nn
    import math
    import torch.utils.model_zoo as model_zoo
    import sys
    from PIL import Image
    import numpy as np
    
    def load_state_dict(model_dir, is_multi_gpu):
        state_dict = torch.load(model_dir, map_location=lambda storage, loc: storage)['state_dict']
        if is_multi_gpu:
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # remove `module.`
                new_state_dict[name] = v
            return new_state_dict
        else:
            return state_dict
    
    
    
    
    
    def parse_parameters():
        parser = argparse.ArgumentParser(description='PyTorch Template')
        parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint (default: None)')  # rarely needed
        parser.add_argument('--debug', action='store_true', dest='debug', help='trainer debug flag')  # normally not needed
        parser.add_argument('--gpu', default='0', type=str, help='GPU ID Select')  # for multiple GPUs use e.g. '0,1,2'
        parser.add_argument('--data_root', default='./datasets/', type=str, help='data root')  # datasets/ contains train/ and val/; each holds one folder of images per defect class (see build_dataset)
        parser.add_argument('--train_file', default='./datasets/train.txt', type=str, help='train file')
        parser.add_argument('--val_file', default='./datasets/val.txt',
                            type=str, help='validation file')
        parser.add_argument('--model', default='resnet50_cbam', type=str, help='model type')
        parser.add_argument('--batch_size', default=4, type=int, help='model train batch size')
        parser.add_argument('--display', action='store_true', dest='display', help='Use TensorboardX to Display')
        parser.add_argument('--classes', default=2, type=int, help='Number of classes')
        parser.add_argument('--work_dir', default='./datasets/work_dir', type=str, help='work directory')
        parser.add_argument('--total_epochs', default=36, type=int, help='total epoch')
    
    
        args = parser.parse_args()
        return args
    
    
    class Logger(object):
        '''Save training process to log file with simple plot function.'''
        def __init__(self, fpath, resume=False):
            self.file = None
            self.resume = resume
            if os.path.isfile(fpath):
                if resume:
                    self.file = open(fpath, 'a')
                else:
                    self.file = open(fpath, 'w')
            else:
                self.file = open(fpath, 'w')
    
        def append(self, target_str):
            # Convert non-string messages (e.g. the args dict) before logging.
            if not isinstance(target_str, str):
                try:
                    target_str = str(target_str)
                except Exception:
                    traceback.print_exc()
                    return
            print(target_str)
            self.file.write(target_str + '\n')
            self.file.flush()
    
        def close(self):
            if self.file is not None:
                self.file.close()
    
    
    
    
    class Concat_patch(object):
        """Crop the four corner patches of a PIL Image and tile them into one image.

        Args:
            margin_ratio (tuple of float): (height ratio, width ratio) of each corner
                patch relative to the full image. With the default (0.25, 0.25) the
                output is roughly half the input height and half the input width.
        """
    
        def __init__(self, margin_ratio=(0.25, 0.25)):
            self.margin_ratio = margin_ratio
    
        def __call__(self, img):
            """
            Args:
                img (PIL Image): 3-channel image to crop.

            Returns:
                PIL Image: The four corner patches tiled into a 2 x 2 grid.
            """
            array_img = np.array(img)
            h, w, c = array_img.shape
            h_margin = int(h * self.margin_ratio[0])
            w_margin = int(w * self.margin_ratio[1])
            patches = [array_img[0:h_margin, 0:w_margin, :], array_img[h - h_margin:, 0:w_margin, :],
                       array_img[0:h_margin, w - w_margin:, :], array_img[h - h_margin:, w - w_margin:, :]]
    
            def concat_patches(patches):
                a = np.concatenate(patches[:2], axis=0)
                b = np.concatenate(patches[2:], axis=0)
                c = np.concatenate([a, b], axis=1)
                return c
    
            img = concat_patches(patches)
            img = Image.fromarray(img)
            return img
    
        def __repr__(self):
            # The class has no `size` attribute; report the parameter it actually stores.
            return self.__class__.__name__ + '(margin_ratio={0})'.format(self.margin_ratio)
    
    
    def build_dataset(args):
        gpus = args.gpu.split(',')
        data_transforms = {
            'train': transforms.Compose([
                Concat_patch(),
                transforms.Resize((224, 224)),
                # transforms.Resize((320, 320)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                # transforms.RandomRotation(90),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
            'val': transforms.Compose([
                Concat_patch(),
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        }
        train_datasets = datasets.ImageFolder(os.path.join(args.data_root, 'train'), data_transforms['train'])
        val_datasets = datasets.ImageFolder(os.path.join(args.data_root, 'val'), data_transforms['val'])
        # sampler = torch.utils.data.WeightedRandomSampler(weights=[1, 1], num_samples=len(train_datasets), replacement=True)
        train_dataloaders = torch.utils.data.DataLoader(train_datasets, batch_size=args.batch_size * len(gpus),
                                                        shuffle=True, num_workers=4)
        val_dataloaders = torch.utils.data.DataLoader(val_datasets, batch_size=4, shuffle=False, num_workers=4)
    
        return train_dataloaders,val_dataloaders
    
    def build_model(args):
        if 'resnet50' == args.model:
            my_model = resnet50(pretrained=False, num_classes=args.classes)
        elif 'resnet50_cbam' == args.model:
            my_model = resnet50_cbam(pretrained=False, num_classes=args.classes)
        elif 'resnet101' == args.model:
            my_model = models.resnet101(pretrained=False, num_classes=args.classes)
        elif 'resnet18' == args.model:
            my_model = models.resnet18(pretrained=False, num_classes=args.classes)
        elif 'resnet18_cbam' == args.model:
            my_model = resnet18_cbam(pretrained=True, num_classes=args.classes)
        else:
            raise ModuleNotFoundError
    
    
        return my_model
    
    
    
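    def build_optimezer(model):
        # Class weights for the two-class weighted loss. The weight stays on the default
        # CUDA device because the model outputs are gathered and processed on cuda:0.
        class_weight = torch.tensor([0.5, 5.0])
        if torch.cuda.is_available():
            class_weight = class_weight.cuda()
        loss_fn = [nn.CrossEntropyLoss(weight=class_weight)]
        # loss_fn = [nn.CrossEntropyLoss()]
        optimizer = optim.SGD(model.parameters(), lr=0.02, momentum=0.9, weight_decay=1e-4)
        lr_schedule = lr_scheduler.MultiStepLR(optimizer, milestones=[16, 24, 32], gamma=0.1)  # lr is updated per epoch
        return loss_fn, optimizer, lr_schedule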
    
    
    
    class Trainer():
        def __init__(self, model, model_type, loss_fn, optimizer, lr_schedule, log_batchs, is_use_cuda, train_data_loader, 
                     valid_data_loader=None, metric=None, start_epoch=0, num_epochs=25, is_debug=False, logger=None,
                      workdir='.'):
            self.model = model
            self.model_type = model_type
            self.loss_fn = loss_fn
            self.optimizer = optimizer
            self.lr_schedule = lr_schedule
            self.log_batchs = log_batchs
            self.is_use_cuda = is_use_cuda
            self.train_data_loader = train_data_loader
            self.valid_data_loader = valid_data_loader
            self.metric = metric
            self.start_epoch = start_epoch
            self.num_epochs = num_epochs
            self.is_debug = is_debug
    
            self.cur_epoch = start_epoch
            self.best_acc = 0.
            self.best_loss = sys.float_info.max
            self.logger = logger
    
            self.workdir = workdir
    
        def fit(self):
            # When resuming, fast-forward the scheduler to the starting epoch.
            for epoch in range(0, self.start_epoch):
                self.lr_schedule.step()

            for epoch in range(self.start_epoch, self.num_epochs):
                self.logger.append('Epoch {}/{}'.format(epoch, self.num_epochs - 1))
                self.logger.append('-' * 60)
                self.cur_epoch = epoch
                if self.is_debug:
                    self._dump_infos()
                self._train()
                self._valid()
                self._save_best_model()
                # Advance scheduler.last_epoch after this epoch's optimizer steps
                # (the order PyTorch >= 1.1 expects); the lr changes at each milestone.
                self.lr_schedule.step()
                print()
    
        def _dump_infos(self):
            self.logger.append('---------------------Current Parameters---------------------')
            self.logger.append('is use GPU: ' + ('True' if self.is_use_cuda else 'False'))
            self.logger.append('lr: %f' % (self.optimizer.param_groups[0]['lr']))
            self.logger.append('model_type: %s' % (self.model_type))
            self.logger.append('current epoch: %d' % (self.cur_epoch))
            self.logger.append('best accuracy: %f' % (self.best_acc))
            self.logger.append('best loss: %f' % (self.best_loss))
            self.logger.append('------------------------------------------------------------')
    
        def _train(self):
            self.model.train()  # Set model to training mode
            losses = []
            if self.metric is not None:
                self.metric[0].reset()
                # self.metric[1].reset()
    
            for i, (inputs, labels) in enumerate(self.train_data_loader):  # Notice
                if self.is_use_cuda:
                    inputs, labels = inputs.cuda(), labels.cuda()
                # Keep labels 1-D even when the last batch holds a single sample.
                labels = labels.view(-1)
    
                self.optimizer.zero_grad()  # clear accumulated gradients
                outputs = self.model(inputs)  # Notice
                loss = self.loss_fn[0](outputs, labels)
                if self.metric is not None:
                    prob = F.softmax(outputs, dim=1).detach().cpu()
                    self.metric[0].add(prob, labels.detach().cpu())
    
                    #one_hot = torch.zeros(prob.shape[0], prob.shape[1]).scatter_(1, labels.cpu(), 1)
                    # self.metric[1].add(prob, labels.data.cpu())
                loss.backward()
                self.optimizer.step()
    
                losses.append(loss.item())  # Notice
                if 0 == i % self.log_batchs or (i == len(self.train_data_loader) - 1):
                    local_time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
                    batch_mean_loss = np.mean(losses)
                    print_str = ('[%s]\tTraining Batch[%d/%d]\t Class Loss: %.4f\t'
                                 % (local_time_str, i, len(self.train_data_loader) - 1, batch_mean_loss))
                    if i == len(self.train_data_loader) - 1 and self.metric is not None:
                        confusion = self.metric[0].value()
                        print(confusion)
                        # top1_acc_score = self.metric[0].value()[0]
                        # top3_acc_score = self.metric[0].value()[1]
                        # print_str += '@Top-1 Score: %.4f	' % (top1_acc_score)
                        # print_str += '@Top-3 Score: %.4f	' % (top3_acc_score)
                        # print(self.metric[1].value())
                    self.logger.append(print_str)
    
    
    
        def _valid(self):
            self.model.eval()
            losses = []
            acc_rate = 0.
            if self.metric is not None:
                self.metric[0].reset()
    
            with torch.no_grad():  # Notice
                for i, (inputs, labels) in enumerate(self.valid_data_loader):
                    if self.is_use_cuda:
                        inputs, labels = inputs.cuda(), labels.cuda()
                        labels = labels.squeeze()
                    else:
                        labels = labels.squeeze()
    
                    if len(labels.shape) == 0:
                        labels = labels.view(-1)
                    outputs = self.model(inputs)  # Notice
                    loss = self.loss_fn[0](outputs, labels)
    
                    if self.metric is not None:
                        prob = F.softmax(outputs, dim=1).data.cpu()
                        self.metric[0].add(prob, labels.data.cpu())
                    losses.append(loss.item())
    
            local_time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
            # self.logger.append(losses)
            batch_mean_loss = np.mean(losses)
            print_str = ('[%s]\tValidation: \t Class Loss: %.4f\t'
                         % (local_time_str, batch_mean_loss))
            if self.metric is not None:
                confusion = self.metric[0].value()
                print(confusion)
                # top1_acc_score = self.metric[0].value()[0]
                # top3_acc_score = self.metric[0].value()[1]
                # print_str += '@Top-1 Score: %.4f	' % (top1_acc_score)
                # print_str += '@Top-3 Score: %.4f	' % (top3_acc_score)
            self.logger.append(print_str)
            # if top1_acc_score >= self.best_acc:
            #     self.best_acc = top1_acc_score
            #     self.best_loss = batch_mean_loss
    
        def _save_best_model(self):
            # Save Model
            self.logger.append('Saving Model...')
            state = {
                'state_dict': self.model.state_dict(),
                'best_acc': self.best_acc,
                'cur_epoch': self.cur_epoch,
                'num_epochs': self.num_epochs
            }
            # Checkpoints go to <workdir>/checkpoint/<model_type>/Models_epoch_<N>.ckpt
            ckpt_dir = os.path.join(self.workdir, 'checkpoint', self.model_type)
            os.makedirs(ckpt_dir, exist_ok=True)
            torch.save(state, os.path.join(ckpt_dir, 'Models_epoch_%d.ckpt' % self.cur_epoch))  # Notice
    
    
    # Build the network
    
    
    
    
    
    model_urls = {
        'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
        'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
        'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
        'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
        'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    }
    def conv3x3(in_planes, out_planes, stride=1):
        "3x3 convolution with padding"
        return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                         padding=1, bias=False)
    class ChannelAttention(nn.Module):
        def __init__(self, in_planes, ratio=16):
            super(ChannelAttention, self).__init__()
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.max_pool = nn.AdaptiveMaxPool2d(1)
    
            # Shared MLP implemented with 1x1 convolutions; `ratio` is the reduction ratio r.
            self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
            self.relu1 = nn.ReLU()
            self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
    
            self.sigmoid = nn.Sigmoid()
    
        def forward(self, x):
            avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
            max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
            out = avg_out + max_out
            return self.sigmoid(out)
    
    class SpatialAttention(nn.Module):
        def __init__(self, kernel_size=7):
            super(SpatialAttention, self).__init__()
    
            assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
            padding = 3 if kernel_size == 7 else 1
    
            self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
            self.sigmoid = nn.Sigmoid()
    
        def forward(self, x):
            avg_out = torch.mean(x, dim=1, keepdim=True)
            max_out, _ = torch.max(x, dim=1, keepdim=True)  # `_` holds the argmax indices; keepdim keeps the channel axis
            x = torch.cat([avg_out, max_out], dim=1)
            x = self.conv1(x)
            return self.sigmoid(x)
    
    class BasicBlock(nn.Module):
        expansion = 1
    
        def __init__(self, inplanes, planes, stride=1, downsample=None):
            super(BasicBlock, self).__init__()
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.bn1 = nn.BatchNorm2d(planes)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = conv3x3(planes, planes)
            self.bn2 = nn.BatchNorm2d(planes)
    
            self.ca = ChannelAttention(planes)
            self.sa = SpatialAttention()
    
            self.downsample = downsample
            self.stride = stride
    
        def forward(self, x):
            residual = x
    
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
    
            out = self.conv2(out)
            out = self.bn2(out)
    
            out = self.ca(out) * out
            out = self.sa(out) * out
    
            if self.downsample is not None:
                residual = self.downsample(x)
    
            out += residual
            out = self.relu(out)
    
            return out
    
    class Bottleneck_CBAM(nn.Module):
        expansion = 4
    
        def __init__(self, inplanes, planes, stride=1, downsample=None):
            super(Bottleneck_CBAM, self).__init__()
            self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
            self.bn1 = nn.BatchNorm2d(planes)
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                   padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(planes)
            self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
            self.bn3 = nn.BatchNorm2d(planes * 4)
            self.relu = nn.ReLU(inplace=True)
    
            self.ca = ChannelAttention(planes * 4)
            self.sa = SpatialAttention()
    
            self.downsample = downsample
            self.stride = stride
    
        def forward(self, x):
            residual = x
    
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
    
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)
    
            out = self.conv3(out)
            out = self.bn3(out)
    
            out = self.ca(out) * out
            out = self.sa(out) * out
    
            if self.downsample is not None:
                residual = self.downsample(x)
    
            out += residual
            out = self.relu(out)
    
            return out
    
    class Bottleneck(nn.Module):
        expansion = 4
    
        def __init__(self, inplanes, planes, stride=1, downsample=None):
            super(Bottleneck, self).__init__()
            self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
            self.bn1 = nn.BatchNorm2d(planes)
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                   padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(planes)
            self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
            self.bn3 = nn.BatchNorm2d(planes * 4)
            self.relu = nn.ReLU(inplace=True)
    
            # self.ca = ChannelAttention(planes * 4)
            # self.sa = SpatialAttention()
    
            self.downsample = downsample
            self.stride = stride
    
        def forward(self, x):
            residual = x
    
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
    
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)
    
            out = self.conv3(out)
            out = self.bn3(out)
    
            # out = self.ca(out) * out
            # out = self.sa(out) * out
    
            if self.downsample is not None:
                residual = self.downsample(x)
    
            out += residual
            out = self.relu(out)
    
            return out
    class ResNet(nn.Module):
    
        def __init__(self, block, layers, num_classes=23):
            self.inplanes = 64
            super(ResNet, self).__init__()
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                                   bias=False)
            self.bn1 = nn.BatchNorm2d(64)
            self.relu = nn.ReLU(inplace=True)
    
            self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
            self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
    
            self.avgpool = nn.AvgPool2d(7, stride=1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)
    
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
    
        def _make_layer(self, block, planes, blocks, stride=1):
            downsample = None
            if stride != 1 or self.inplanes != planes * block.expansion:
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes * block.expansion,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(planes * block.expansion),
                )
    
            layers = []
            layers.append(block(self.inplanes, planes, stride, downsample))
            self.inplanes = planes * block.expansion
            for i in range(1, blocks):
                layers.append(block(self.inplanes, planes))
    
            return nn.Sequential(*layers)
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
    
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
    
            x = self.avgpool(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
    
            return x
    def resnet18_cbam(pretrained=False, **kwargs):
        """Constructs a ResNet-18 model.
    
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
        """
        model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
        if pretrained:
            pretrained_state_dict = model_zoo.load_url(model_urls['resnet18'])
            now_state_dict = model.state_dict()
            now_state_dict.update(pretrained_state_dict)
            now_state_dict.pop('fc.weight')
            now_state_dict.pop('fc.bias')
            model.load_state_dict(now_state_dict, strict=False)
        return model
    
    def resnet34_cbam(pretrained=False, **kwargs):
        """Constructs a ResNet-34 model.
    
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
        """
        model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
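        if pretrained:
            pretrained_state_dict = model_zoo.load_url(model_urls['resnet34'])
            now_state_dict = model.state_dict()
            # Copy only tensors whose name and shape match, so a custom num_classes
            # (different fc shape) does not break loading.
            now_state_dict.update({k: v for k, v in pretrained_state_dict.items()
                                   if k in now_state_dict and v.shape == now_state_dict[k].shape})
            model.load_state_dict(now_state_dict)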
        return model
    
    def resnet50_cbam(pretrained=False, **kwargs):
        """Constructs a ResNet-50 model.
    
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
        """
        model = ResNet(Bottleneck_CBAM, [3, 4, 6, 3], **kwargs)
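        if pretrained:
            pretrained_state_dict = model_zoo.load_url(model_urls['resnet50'])
            now_state_dict = model.state_dict()
            # Copy only tensors whose name and shape match, so a custom num_classes
            # (different fc shape) does not break loading.
            now_state_dict.update({k: v for k, v in pretrained_state_dict.items()
                                   if k in now_state_dict and v.shape == now_state_dict[k].shape})
            model.load_state_dict(now_state_dict)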
        return model
    
    def resnet50(pretrained=False, **kwargs):
        """Constructs a ResNet-50 model.
    
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
        """
        model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
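        if pretrained:
            pretrained_state_dict = model_zoo.load_url(model_urls['resnet50'])
            now_state_dict = model.state_dict()
            # Copy only tensors whose name and shape match, so a custom num_classes
            # (different fc shape) does not break loading.
            now_state_dict.update({k: v for k, v in pretrained_state_dict.items()
                                   if k in now_state_dict and v.shape == now_state_dict[k].shape})
            model.load_state_dict(now_state_dict)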
        return model
    
    def resnet101_cbam(pretrained=False, **kwargs):
        """Constructs a ResNet-101 model.
    
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
        """
        model = ResNet(Bottleneck_CBAM, [3, 4, 23, 3], **kwargs)
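        if pretrained:
            pretrained_state_dict = model_zoo.load_url(model_urls['resnet101'])
            now_state_dict = model.state_dict()
            # Copy only tensors whose name and shape match, so a custom num_classes
            # (different fc shape) does not break loading.
            now_state_dict.update({k: v for k, v in pretrained_state_dict.items()
                                   if k in now_state_dict and v.shape == now_state_dict[k].shape})
            model.load_state_dict(now_state_dict)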
        return model
    
    def resnet152_cbam(pretrained=False, **kwargs):
        """Constructs a ResNet-152 model.
    
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
        """
        model = ResNet(Bottleneck_CBAM, [3, 8, 36, 3], **kwargs)
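        if pretrained:
            pretrained_state_dict = model_zoo.load_url(model_urls['resnet152'])
            now_state_dict = model.state_dict()
            # Copy only tensors whose name and shape match, so a custom num_classes
            # (different fc shape) does not break loading.
            now_state_dict.update({k: v for k, v in pretrained_state_dict.items()
                                   if k in now_state_dict and v.shape == now_state_dict[k].shape})
            model.load_state_dict(now_state_dict)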
        return model
    
    
    def train():
        args=parse_parameters()
        logger = Logger('./' + args.model + '.log') if len(args.resume)==0 else Logger('./' + args.model + '.log', True)
        logger.append(vars(args))
    
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
        is_use_cuda = torch.cuda.is_available()
        cudnn.benchmark = True
    
        train_dataloaders, val_dataloaders = build_dataset(args)
        model=build_model(args)
        loss_fn, optimizer, lr_schedule = build_optimezer(model)
    
    
        if is_use_cuda and 1 == len(args.gpu.split(',')):
            model = model.cuda()
        elif is_use_cuda and 1 < len(args.gpu.split(',')):
            model = nn.DataParallel(model.cuda())   # model.cuda() places the master copy on cuda:0
    
    
        metric = [ConfusionMeter(args.classes)]  # confusion matrix sized by --classes
        start_epoch = 0
        my_trainer = Trainer(model, args.model, loss_fn, optimizer, lr_schedule, 10, is_use_cuda, train_dataloaders,
                             val_dataloaders, metric, start_epoch, args.total_epochs, args.debug, logger,  args.work_dir)
        my_trainer.fit()
        logger.append('Optimize Done!')
    
    
    
    
    
    
    
    
    
    if __name__ == '__main__':
    
    
        train()
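
    The script's learning-rate settings (base lr 0.02, milestones [16, 24, 32], gamma 0.1, 36 epochs by default) can be checked by hand. The sketch below is plain Python that applies a MultiStepLR-style rule; multistep_lr is a hypothetical helper written for this note, not a PyTorch function, and `epoch` here means the number of scheduler steps taken so far.

```python
def multistep_lr(base_lr, milestones, gamma, epoch):
    """Learning rate after `epoch` scheduler steps: decay by gamma at each milestone passed."""
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed

# Values taken from the training script's defaults.
schedule = [multistep_lr(0.02, [16, 24, 32], 0.1, e) for e in range(36)]

for e in (0, 15, 16, 24, 32, 35):
    print(e, schedule[e])
```

    The rate stays at 0.02 until the scheduler has stepped 16 times, then drops by a factor of 10 at each milestone.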


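    The test code takes the model definition from the training code, so the training script and the test script must sit in the same folder for the import to work.
    The test code is as follows: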


    from collections import OrderedDict
    from PIL import Image
    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable
    from torchvision import transforms
    import numpy as np
    
    from train_new import resnet50_cbam  # the training script above; change the module name to match your file name
    
    def init_cls_model(checkpoint_path, is_multi_gpu=False, classes=2):
    
        my_model = resnet50_cbam(num_classes=classes)
        state_dict = torch.load(checkpoint_path)['state_dict']
        if is_multi_gpu:
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # remove `module.`
                new_state_dict[name] = v
            my_model.load_state_dict(new_state_dict)
        else:
            my_model.load_state_dict(state_dict)
    
        my_model = my_model.cuda()
        my_model.eval()
    
        return my_model
    
    class Concat_patch(object):  # corner-patch crop; optional in practice
        """Crop the four corner patches of a PIL Image and tile them into one image.

        Args:
            margin_ratio (tuple of float): (height ratio, width ratio) of each corner
                patch relative to the full image. With the default (0.25, 0.25) the
                output is roughly half the input height and half the input width.
        """
    
        def __init__(self, margin_ratio=(0.25, 0.25)):
            self.margin_ratio = margin_ratio
    
        def __call__(self, img):
            """
            Args:
                img (PIL Image): 3-channel image to crop.

            Returns:
                PIL Image: The four corner patches tiled into a 2 x 2 grid.
            """
            array_img = np.array(img)
            h, w, c = array_img.shape
            h_margin = int(h * self.margin_ratio[0])
            w_margin = int(w * self.margin_ratio[1])
            patches = [array_img[0:h_margin, 0:w_margin, :], array_img[h - h_margin:, 0:w_margin, :],
                       array_img[0:h_margin, w - w_margin:, :], array_img[h - h_margin:, w - w_margin:, :]]
    
            def concat_patches(patches):
                a = np.concatenate(patches[:2], axis=0)
                b = np.concatenate(patches[2:], axis=0)
                c = np.concatenate([a, b], axis=1)
                return c
    
            img = concat_patches(patches)
            img = Image.fromarray(img)
            return img
    
        def __repr__(self):
            # The class has no `size` attribute; report the parameter it actually stores.
            return self.__class__.__name__ + '(margin_ratio={0})'.format(self.margin_ratio)
    
    def cls_judge(img_path, model, img_size=224):
        FALSE_NAME = 'FALSE'
        NG_NAME = 'NG'
    
        CLS_NAME = [FALSE_NAME, NG_NAME]
        data_transform = transforms.Compose([
            Concat_patch(),
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    
    
    
        file_path = img_path
    
        with torch.no_grad():
            img_tensor = data_transform(Image.open(file_path).convert('RGB')).unsqueeze(0)
            img_tensor = img_tensor.cuda()  # torch.no_grad() already disables autograd; volatile=True is no longer supported
            output = F.softmax(model(img_tensor), dim=1).cpu().numpy()
        # defect_prob = round(output.data[0, 1], 6)
        pred = np.argmax(output)
        pred = CLS_NAME[pred]
    
        score = np.max(output)
        if pred == FALSE_NAME:
            score = 0
        if score <= 0.85 and pred == NG_NAME:
            pred = FALSE_NAME
            score = 0
    
        return pred, score
    
    
    if __name__ == '__main__':
        # Example paths from the author's machine; replace them with your own checkpoint and image.
        model_path = r'E:\code_tj\CBAM_PyTorch\datasets\work_dir\checkpoint\resnet50-cbam\Models_epoch_0.ckpt'
        img = r'E:\code_tj\CBAM_PyTorch\datasets\val\v06W0C2P0206A0108_WHITE_20210125.jpg'
        model = init_cls_model(model_path, is_multi_gpu=False, classes=2)
        pre = cls_judge(img, model, img_size=224)
        print(pre)
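
    The decision rule at the end of cls_judge is easy to miss: an image is reported as NG only when NG is the top class and its softmax probability exceeds 0.85; everything else is reported as FALSE with score 0. A minimal standalone restatement of that rule, using a hypothetical helper named decide and made-up probabilities:

```python
def decide(probs, class_names=('FALSE', 'NG'), ng_threshold=0.85):
    """Return (label, score) following the thresholding used in cls_judge."""
    pred_idx = max(range(len(probs)), key=lambda i: probs[i])  # argmax over classes
    pred = class_names[pred_idx]
    score = probs[pred_idx]
    # FALSE predictions, and NG predictions at or below the threshold, both map to FALSE / 0.
    if pred == 'FALSE' or score <= ng_threshold:
        return 'FALSE', 0
    return pred, score

print(decide([0.10, 0.90]))  # confident NG
print(decide([0.20, 0.80]))  # NG below the threshold
print(decide([0.70, 0.30]))  # FALSE
```

    Raising the threshold trades missed defects for fewer false alarms, so the value 0.85 is worth tuning on your own validation set.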
     


    Reference blog: https://blog.csdn.net/qq_14845119/article/details/81393127

  • Original post: https://www.cnblogs.com/tangjunjun/p/14868891.html