  • Person Re-identification (ReID): Modifying the Workflow Based on Person_reID_baseline_pytorch

    Download Person_reID_baseline_pytorch: https://github.com/layumi/Person_reID_baseline_pytorch/tree/master/tutorial
    Download the Market1501 dataset: http://www.liangzheng.org/Project/project_reid.html
    Market1501 dataset structure:

    ├── Market/
    │   ├── bounding_box_test/          /* Files for testing (candidate images pool)
    │   ├── bounding_box_train/         /* Files for training 
    │   ├── gt_bbox/                    /* We do not use it 
    │   ├── gt_query/                   /* Files for multiple query testing 
    │   ├── query/                      /* Files for testing (query images)
    │   ├── readme.txt
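The get_id helper in test.py and demo.py depends on the Market-1501 file naming convention, for example 0002_c1s1_000451_03.jpg. The leading field is the person ID, and names starting with -1 are junk images. The digit after "c" is the camera. Below is a minimal parsing sketch under that assumption. parse_market_name is a hypothetical name and is not part of the repository.

```python
import os
import re

# Market-1501 style name: <person id>_c<camera>s<sequence>_<frame>_<box>.jpg
# e.g. 0002_c1s1_000451_03.jpg; junk images start with -1.
MARKET_NAME = re.compile(r"^(-1|\d{4})_c(\d)s(\d)_")


def parse_market_name(path):
    """Return (person_id, camera_id) parsed from a Market-1501 file name.

    Extracts the same two fields as get_id(); raises ValueError for names
    that do not follow the convention (e.g. arbitrary custom images).
    """
    name = os.path.basename(path)
    match = MARKET_NAME.match(name)
    if match is None:
        raise ValueError("not a Market-1501 style name: %s" % name)
    return int(match.group(1)), int(match.group(2))


print(parse_market_name("bounding_box_test/0002_c1s1_000451_03.jpg"))  # (2, 1)
print(parse_market_name("-1_c3s2_012345_01.jpg"))                      # (-1, 3)
```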
    

    Point prepare.py at the Market directory and run python prepare.py. The dataset structure afterwards:

    ├── Market/
    │   ├── bounding_box_test/          /* Files for testing (candidate images pool)
    │   ├── bounding_box_train/         /* Files for training 
    │   ├── gt_bbox/                    /* We do not use it 
    │   ├── gt_query/                   /* Files for multiple query testing 
    │   ├── query/                      /* Files for testing (query images)
    │   ├── readme.txt
    │   ├── pytorch/
    │       ├── train/                   /* train 
    │           ├── 0002
    │           ├── 0007
    │           ...
    │       ├── val/                     /* val
    │       ├── train_all/               /* train+val      
    │       ├── query/                   /* query files  
    │       ├── gallery/                 /* gallery files
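test.py and demo.py load gallery/ and query/ with datasets.ImageFolder, which treats every subfolder as one class. That is why the pytorch/ tree above groups images into one folder per person ID. The stdlib sketch below illustrates that reorganization for a single split. split_by_id is hypothetical and only approximates what prepare.py does; for example, it does not build the train/val split.

```python
import os
import shutil


def split_by_id(src_dir, dst_dir):
    """Copy every .jpg in src_dir into dst_dir/<person id>/ (ImageFolder layout).

    The person ID is the part of the file name before the first underscore
    (Market-1501 convention, e.g. 0002_c1s1_000451_03.jpg -> 0002).
    Returns the number of images copied.
    """
    copied = 0
    for name in sorted(os.listdir(src_dir)):
        if not name.endswith(".jpg"):
            continue  # skip readme files and other non-image entries
        person_id = name.split("_")[0]
        target = os.path.join(dst_dir, person_id)
        os.makedirs(target, exist_ok=True)
        shutil.copyfile(os.path.join(src_dir, name), os.path.join(target, name))
        copied += 1
    return copied


# Hypothetical usage mirroring the layout above:
# split_by_id("Market/bounding_box_test", "Market/pytorch/gallery")
# split_by_id("Market/query", "Market/pytorch/query")
```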
    

    Train and test the model. First set the dataset path in train.py and test.py (the --test_dir argument in test.py) to /home/hylink/eclipse-workspace/reID/Market/pytorch, then run:

    python train.py
    python test.py
    python demo.py --query_index 777
    

    Results:

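    Modified test.py (instead of extracting features for both the gallery and the query set, it now builds the feature database from the gallery only):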

    # -*- coding: utf-8 -*-
    
    from __future__ import print_function, division
    
    import argparse
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.optim import lr_scheduler
    from torch.autograd import Variable
    import numpy as np
    import torchvision
    from torchvision import datasets, models, transforms
    import time
    import os
    import scipy.io
    from model import ft_net, ft_net_dense, PCB, PCB_test
    
    ######################################################################
    # Options
    # --------
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0  0,1,2  0,2')
    parser.add_argument('--which_epoch',default='last', type=str, help='0,1,2,3...or last')
    parser.add_argument('--test_dir',default='/home/hylink/eclipse-workspace/reID/Market/pytorch',type=str, help='./test_data')
    parser.add_argument('--name', default='ft_ResNet50', type=str, help='save model path')
    parser.add_argument('--batchsize', default=32, type=int, help='batchsize')
    parser.add_argument('--use_dense', action='store_true', help='use densenet121' )
    parser.add_argument('--PCB', action='store_true', help='use PCB' )
    parser.add_argument('--multi', action='store_true', help='use multiple query' )
    
    opt = parser.parse_args()
    
    str_ids = opt.gpu_ids.split(',')
    #which_epoch = opt.which_epoch
    name = opt.name
    test_dir = opt.test_dir
    
    gpu_ids = []
    for str_id in str_ids:
        id = int(str_id)
        if id >=0:
            gpu_ids.append(id)
    
    # set gpu ids
    if len(gpu_ids)>0:
        torch.cuda.set_device(gpu_ids[0])
    
    ######################################################################
    # Load Data
    # ---------
    #
    # We will use torchvision and torch.utils.data packages for loading the
    # data.
    #
    data_transforms = transforms.Compose([
            transforms.Resize((288,144), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ############### Ten Crop        
            #transforms.TenCrop(224),
            #transforms.Lambda(lambda crops: torch.stack(
             #   [transforms.ToTensor()(crop) 
              #      for crop in crops]
               # )),
            #transforms.Lambda(lambda crops: torch.stack(
             #   [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(crop)
              #       for crop in crops]
              # ))
    ])
    
    if opt.PCB:
        data_transforms = transforms.Compose([
            transforms.Resize((384,192), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 
        ])
    
    
    data_dir = test_dir
    
    if opt.multi:
        image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query','multi-query']}
        dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                                 shuffle=False, num_workers=16) for x in ['gallery','query','multi-query']}
    else:
        image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery']}
        dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                                 shuffle=False, num_workers=16) for x in ['gallery']}
    #class_names = image_datasets['query'].classes
    use_gpu = torch.cuda.is_available()
    
    ######################################################################
    # Load model
    #---------------------------
    def load_network(network):
        save_path = os.path.join('./model',name,'net_%s.pth'%opt.which_epoch)
        network.load_state_dict(torch.load(save_path))
        return network
    
    
    ######################################################################
    # Extract feature
    # ----------------------
    #
    # Extract feature from  a trained model.
    #
    def fliplr(img):
        '''flip horizontal'''
        inv_idx = torch.arange(img.size(3)-1,-1,-1).long()  # N x C x H x W
        img_flip = img.index_select(3,inv_idx)
        return img_flip
    
    def extract_feature(model,dataloaders):
        features = torch.FloatTensor()
        count = 0
        for data in dataloaders:
            img, label = data
            n, c, h, w = img.size()
            count += n
            print(count)
            if opt.use_dense:
                ff = torch.FloatTensor(n,1024).zero_()
            else:
                ff = torch.FloatTensor(n,2048).zero_()
            if opt.PCB:
                ff = torch.FloatTensor(n,2048,6).zero_() # we have six parts
            for i in range(2):
                if(i==1):
                    img = fliplr(img)
                input_img = Variable(img.cuda())
                outputs = model(input_img) 
                f = outputs.data.cpu()
                ff = ff+f
            # norm feature
            if opt.PCB:
                # feature size (n,2048,6)
                # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature.
                # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6).
                fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(6) 
                ff = ff.div(fnorm.expand_as(ff))
                ff = ff.view(ff.size(0), -1)
            else:
                fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                ff = ff.div(fnorm.expand_as(ff))
    
            features = torch.cat((features,ff), 0)
        return features
    
    def get_id(img_path):
        camera_id = []
        labels = []
        for path, v in img_path:
            #filename = path.split('/')[-1]
            filename = os.path.basename(path)
            label = filename[0:4]
            camera = filename.split('c')[1]
            if label[0:2]=='-1':
                labels.append(-1)
            else:
                labels.append(int(label))
            camera_id.append(int(camera[0]))
        return camera_id, labels
    
    gallery_path = image_datasets['gallery'].imgs
    #query_path = image_datasets['query'].imgs
    
    gallery_cam,gallery_label = get_id(gallery_path)
    #query_cam,query_label = get_id(query_path)
    
    if opt.multi:
        mquery_path = image_datasets['multi-query'].imgs
        mquery_cam,mquery_label = get_id(mquery_path)
    
    ######################################################################
    # Load Collected data Trained model
    print('-------test-----------')
    if opt.use_dense:
        model_structure = ft_net_dense(751)
    else:
        model_structure = ft_net(751)
    
    if opt.PCB:
        model_structure = PCB(751)
    
    model = load_network(model_structure)
    
    # Remove the final fc layer and classifier layer
    if not opt.PCB:
        model.model.fc = nn.Sequential()
        model.classifier = nn.Sequential()
    else:
        model = PCB_test(model)
    
    # Change to test mode
    model = model.eval()
    if use_gpu:
        model = model.cuda()
    
    # Extract feature
    gallery_feature = extract_feature(model,dataloaders['gallery'])
    #query_feature = extract_feature(model,dataloaders['query'])
    if opt.multi:
        mquery_feature = extract_feature(model,dataloaders['multi-query'])
        
    # Save to Matlab for check
    #result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam,'query_f':query_feature.numpy(),'query_label':query_label,'query_cam':query_cam}
    result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam}
    
    scipy.io.savemat('pytorch_result.mat',result)
    if opt.multi:
        result = {'mquery_f':mquery_feature.numpy(),'mquery_label':mquery_label,'mquery_cam':mquery_cam}
        scipy.io.savemat('multi_query.mat',result)
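extract_feature adds the outputs for each image and its horizontal flip (fliplr) and then L2-normalizes the result. This makes the dot product used later in demo.py a cosine similarity. In the PCB branch, each 2048-dim part is divided by its own norm times sqrt(6). Every part therefore ends up with norm 1/sqrt(6), and the flattened 12288-dim vector has norm sqrt(6 * 1/6) = 1. Below is a NumPy check of that arithmetic, with random arrays standing in for network outputs. normalize_pcb and normalize_global are hypothetical names.

```python
import numpy as np


def normalize_pcb(ff):
    """(n, 2048, 6) part features -> (n, 2048 * 6) vectors of unit L2 norm.

    Same arithmetic as the PCB branch of extract_feature: every part is
    divided by (its own L2 norm * sqrt(num_parts)), then the array is flattened.
    """
    num_parts = ff.shape[2]
    fnorm = np.linalg.norm(ff, axis=1, keepdims=True) * np.sqrt(num_parts)
    return (ff / fnorm).reshape(ff.shape[0], -1)


def normalize_global(ff):
    """(n, d) features -> each row scaled to unit L2 norm (non-PCB branch)."""
    return ff / np.linalg.norm(ff, axis=1, keepdims=True)


rng = np.random.default_rng(0)
# Random stand-ins for model(img) + model(fliplr(img)).
parts = rng.normal(size=(4, 2048, 6)) + rng.normal(size=(4, 2048, 6))
pcb = normalize_pcb(parts)
print(pcb.shape)                      # (4, 12288)
print(np.linalg.norm(pcb, axis=1))    # every row has norm 1 up to rounding

flat = normalize_global(rng.normal(size=(4, 2048)))
print(np.diag(flat @ flat.T))         # cosine of each vector with itself is 1
```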
    
    
    

    Modified demo.py (extracts features for the images under the query path, compares them against the gallery database, and displays the results):

    # -*- coding: utf-8 -*-
    
    from __future__ import print_function, division
    
    import argparse
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.optim import lr_scheduler
    from torch.autograd import Variable
    import numpy as np
    import torchvision
    from torchvision import datasets, models, transforms
    import time
    import os
    import scipy.io
    import matplotlib.pyplot as plt
    from model import ft_net, ft_net_dense, PCB, PCB_test
    
    ######################################################################
    # Options
    # --------
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0  0,1,2  0,2')
    parser.add_argument('--which_epoch',default='last', type=str, help='0,1,2,3...or last')
    parser.add_argument('--test_dir',default='/home/hylink/eclipse-workspace/reID/Market/pytorch',type=str, help='./test_data')
    parser.add_argument('--name', default='ft_ResNet50', type=str, help='save model path')
    parser.add_argument('--batchsize', default=32, type=int, help='batchsize')
    parser.add_argument('--use_dense', action='store_true', help='use densenet121' )
    parser.add_argument('--PCB', action='store_true', help='use PCB' )
    parser.add_argument('--multi', action='store_true', help='use multiple query' )
    parser.add_argument('--query_index', default=3, type=int, help='test_image_index')
    
    opt = parser.parse_args()
    
    str_ids = opt.gpu_ids.split(',')
    #which_epoch = opt.which_epoch
    name = opt.name
    test_dir = opt.test_dir
    
    gpu_ids = []
    for str_id in str_ids:
        id = int(str_id)
        if id >=0:
            gpu_ids.append(id)
    
    # set gpu ids
    if len(gpu_ids)>0:
        torch.cuda.set_device(gpu_ids[0])
    
    ######################################################################
    # Load Data
    # ---------
    #
    # We will use torchvision and torch.utils.data packages for loading the
    # data.
    #
    data_transforms = transforms.Compose([
            transforms.Resize((288,144), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ############### Ten Crop        
            #transforms.TenCrop(224),
            #transforms.Lambda(lambda crops: torch.stack(
             #   [transforms.ToTensor()(crop) 
              #      for crop in crops]
               # )),
            #transforms.Lambda(lambda crops: torch.stack(
             #   [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(crop)
              #       for crop in crops]
              # ))
    ])
    
    if opt.PCB:
        data_transforms = transforms.Compose([
            transforms.Resize((384,192), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 
        ])
    
    
    data_dir = test_dir
    
    if opt.multi:
        image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query','multi-query']}
        dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                                 shuffle=False, num_workers=16) for x in ['gallery','query','multi-query']}
    else:
        image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query']}
        dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                                 shuffle=False, num_workers=16) for x in ['gallery','query']}
    class_names = image_datasets['query'].classes
    use_gpu = torch.cuda.is_available()
    
    ######################################################################
    # Load model
    #---------------------------
    def load_network(network):
        save_path = os.path.join('./model',name,'net_%s.pth'%opt.which_epoch)
        network.load_state_dict(torch.load(save_path))
        return network
    
    
    ######################################################################
    # Extract feature
    # ----------------------
    #
    # Extract feature from  a trained model.
    #
    def fliplr(img):
        '''flip horizontal'''
        inv_idx = torch.arange(img.size(3)-1,-1,-1).long()  # N x C x H x W
        img_flip = img.index_select(3,inv_idx)
        return img_flip
    
    def extract_feature(model,dataloaders):
        features = torch.FloatTensor()
        count = 0
        for data in dataloaders:
            img, label = data
            n, c, h, w = img.size()
            count += n
            print(count)
            if opt.use_dense:
                ff = torch.FloatTensor(n,1024).zero_()
            else:
                ff = torch.FloatTensor(n,2048).zero_()
            if opt.PCB:
                ff = torch.FloatTensor(n,2048,6).zero_() # we have six parts
            for i in range(2):
                if(i==1):
                    img = fliplr(img)
                input_img = Variable(img.cuda())
                outputs = model(input_img) 
                f = outputs.data.cpu()
                ff = ff+f
            # norm feature
            if opt.PCB:
                # feature size (n,2048,6)
                # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature.
                # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6).
                fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(6) 
                ff = ff.div(fnorm.expand_as(ff))
                ff = ff.view(ff.size(0), -1)
            else:
                fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                ff = ff.div(fnorm.expand_as(ff))
    
            features = torch.cat((features,ff), 0)
        return features
    
    def get_id(img_path):
        camera_id = []
        labels = []
        for path, v in img_path:
            #filename = path.split('/')[-1]
            filename = os.path.basename(path)
            label = filename[0:4]
            camera = filename.split('c')[1]
            if label[0:2]=='-1':
                labels.append(-1)
            else:
                labels.append(int(label))
            camera_id.append(int(camera[0]))
        return camera_id, labels
    
    query_path = image_datasets['query'].imgs
    
    query_cam,query_label = get_id(query_path)
    
    if opt.multi:
        mquery_path = image_datasets['multi-query'].imgs
        mquery_cam,mquery_label = get_id(mquery_path)
    
    ######################################################################
    # Load Collected data Trained model
    print('-------test-----------')
    if opt.use_dense:
        model_structure = ft_net_dense(751)
    else:
        model_structure = ft_net(751)
    
    if opt.PCB:
        model_structure = PCB(751)
    
    model = load_network(model_structure)
    
    # Remove the final fc layer and classifier layer
    if not opt.PCB:
        model.model.fc = nn.Sequential()
        model.classifier = nn.Sequential()
    else:
        model = PCB_test(model)
    
    # Change to test mode
    model = model.eval()
    if use_gpu:
        model = model.cuda()
    
    # Extract feature
    
    query_feature = extract_feature(model,dataloaders['query'])
    
    ######################################################################
    ######################################################################
    
    
    def imshow(path, title=None):
        """Display the image file at `path`, with an optional title."""
        im = plt.imread(path)
        plt.imshow(im)
        if title is not None:
            plt.title(title)
        plt.pause(0.001)  # pause a bit so that plots are updated
    
    ######################################################################
    result = scipy.io.loadmat('pytorch_result.mat')
    gallery_feature = torch.FloatTensor(result['gallery_f'])
    gallery_cam = result['gallery_cam'][0]
    gallery_label = result['gallery_label'][0]
    
    query_feature = query_feature.cuda()
    gallery_feature = gallery_feature.cuda()
    
    #######################################################################
    # sort the images
    def sort_img(qf, ql, qc, gf, gl, gc):
        query = qf.view(-1,1)
        # print(query.shape)
        score = torch.mm(gf,query)
        score = score.squeeze(1).cpu()
        score = score.numpy()
        # predict index
        index = np.argsort(score)  #from small to large
        index = index[::-1]
        # index = index[0:2000]
        # good index
        query_index = np.argwhere(gl==ql)
        #same camera
        camera_index = np.argwhere(gc==qc)
    
        junk_index1 = np.argwhere(gl==-1)
        junk_index2 = np.intersect1d(query_index, camera_index)
        junk_index = np.append(junk_index2, junk_index1) 
    
        mask = np.in1d(index, junk_index, invert=True)
        index = index[mask]
        return index
    
    i = opt.query_index
    index = sort_img(query_feature[i],query_label[i],query_cam[i],gallery_feature,gallery_label,gallery_cam)
    
    ########################################################################
    # Visualize the rank result
    
    query_path, _ = image_datasets['query'].imgs[i]
    query_label = query_label[i]
    print(query_path)
    print('Top 10 images are as follow:')
    try: # Visualize Ranking Result 
        # Graphical User Interface is needed
        fig = plt.figure(figsize=(16,4))
        ax = plt.subplot(1,11,1)
        ax.axis('off')
        imshow(query_path,'query')
        for i in range(10):
            ax = plt.subplot(1,11,i+2)
            ax.axis('off')
            img_path, _ = image_datasets['gallery'].imgs[index[i]]
            label = gallery_label[index[i]]
            imshow(img_path)
            if label == query_label:
                ax.set_title('%d'%(i+1), color='green')
            else:
                ax.set_title('%d'%(i+1), color='red')
            print(img_path)
        fig.savefig("show.png")  # only reached if the figure was created successfully
    except RuntimeError:
        for i in range(10):
            img_path = image_datasets['gallery'].imgs[index[i]]
            print(img_path[0])
        print('If you want to see the visualization of the ranking result, graphical user interface is needed.')
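sort_img in demo.py ranks the gallery by the dot product between unit-length features, which is the cosine similarity. It then drops junk entries: gallery images labelled -1, and images of the query's person taken by the query's camera. The NumPy sketch below reproduces that logic on toy data so the filtering can be checked in isolation. rank_gallery and the 2-D features are illustrative only. np.isin stands in for np.in1d and gives the same result for 1-D arrays.

```python
import numpy as np


def rank_gallery(qf, ql, qc, gf, gl, gc):
    """Return gallery indices sorted by descending cosine similarity,
    excluding junk: label -1, or same label AND same camera as the query.

    qf: (d,) unit query feature; gf: (n, d) unit gallery features;
    gl, gc: (n,) gallery person ids and camera ids.
    """
    score = gf @ qf                        # cosine similarity for unit vectors
    index = np.argsort(score)[::-1]        # largest score first
    same_id = np.argwhere(gl == ql)
    same_cam = np.argwhere(gc == qc)
    junk = np.append(np.intersect1d(same_id, same_cam), np.argwhere(gl == -1))
    return index[np.isin(index, junk, invert=True)]


# Toy gallery: 4 images with 2-D unit features (illustrative values only).
gf = np.array([[1.0, 0.0], [0.8, 0.6], [0.0, 1.0], [0.6, 0.8]])
gl = np.array([2, 2, 7, -1])   # person ids; -1 = junk
gc = np.array([1, 3, 2, 4])    # camera ids
qf = np.array([1.0, 0.0])      # query: person 2 seen by camera 1

print(rank_gallery(qf, 2, 1, gf, gl, gc))  # [1 2]
```

The returned positions index the gallery feature matrix, and demo.py maps them back to files through image_datasets['gallery'].imgs. This only works if test.py and demo.py see the same gallery folder. ImageFolder sorts class folders and file names, so the order stays stable as long as pytorch/gallery/ is not changed after pytorch_result.mat is written.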
        
    

    Place the custom gallery images under pytorch/gallery/.
    Place the custom query images under pytorch/query/.
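Custom images must fit what the two scripts assume. datasets.ImageFolder only loads images that sit inside a subfolder of pytorch/gallery/ or pytorch/query/. get_id reads the person ID from the first four characters of the file name and the camera from the digit after the first "c". The stdlib check below lists files likely to break either assumption. It is a hypothetical helper, not part of the repository, and the set of image extensions it checks is also an assumption.

```python
import os

IMAGE_EXTS = (".jpg", ".jpeg", ".png")  # assumption: custom images use these extensions


def check_custom_dir(root):
    """List image files under root that ImageFolder would skip or get_id() cannot parse.

    get_id() takes the person id from name[0:4] (or -1 if the name starts
    with "-1") and the camera from the first character after the first 'c'.
    """
    problems = []
    for entry in sorted(os.listdir(root)):
        path = os.path.join(root, entry)
        if os.path.isfile(path) and entry.lower().endswith(IMAGE_EXTS):
            problems.append("%s: not inside a class subfolder, ImageFolder skips it" % entry)
    for dirpath, _, filenames in os.walk(root):
        if dirpath == root:
            continue  # loose files were handled above
        for name in sorted(filenames):
            if not name.lower().endswith(IMAGE_EXTS):
                continue
            after_c = name.split("c")
            id_ok = name[0:2] == "-1" or name[0:4].isdigit()
            cam_ok = len(after_c) > 1 and after_c[1][:1].isdigit()
            if not (id_ok and cam_ok):
                problems.append("%s: does not match the <id>_c<camera>... naming" % name)
    return problems


# Hypothetical usage:
# for message in check_custom_dir("Market/pytorch/gallery"):
#     print(message)
```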
    Results:

  • Original article: https://www.cnblogs.com/gmhappy/p/11864018.html