  • Person Re-Identification (ReID): Adapting the Person_reID_baseline_pytorch Workflow

    Download Person_reID_baseline_pytorch: https://github.com/layumi/Person_reID_baseline_pytorch/tree/master/tutorial
    Download the Market1501 dataset: http://www.liangzheng.org/Project/project_reid.html
    Market1501 dataset structure:

    ├── Market/
    │   ├── bounding_box_test/          /* Files for testing (candidate images pool)
    │   ├── bounding_box_train/         /* Files for training 
    │   ├── gt_bbox/                    /* We do not use it 
    │   ├── gt_query/                   /* Files for multiple query testing 
    │   ├── query/                      /* Files for testing (query images)
    │   ├── readme.txt
    

    After updating the --test_dir path and running python prepare.py, the dataset structure becomes:

    ├── Market/
    │   ├── bounding_box_test/          /* Files for testing (candidate images pool)
    │   ├── bounding_box_train/         /* Files for training 
    │   ├── gt_bbox/                    /* We do not use it 
    │   ├── gt_query/                   /* Files for multiple query testing 
    │   ├── query/                      /* Files for testing (query images)
    │   ├── readme.txt
    │   ├── pytorch/
    │       ├── train/                   /* train 
    │           ├── 0002
    │           ├── 0007
    │           ...
    │       ├── val/                     /* val
    │       ├── train_all/               /* train+val      
    │       ├── query/                   /* query files  
    │       ├── gallery/                 /* gallery files
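
    The regrouping shown above, from a flat image folder into one subfolder per person ID, can be sketched roughly as follows. This is a minimal sketch of the idea, not the repository's prepare.py: the function name group_by_identity is hypothetical, and it assumes the person ID is the filename prefix before the first underscore (as in 0002_c1s1_000451_03.jpg). One folder per identity is the layout that torchvision's ImageFolder expects.

```python
import shutil
from pathlib import Path


def group_by_identity(src_dir, dst_dir):
    """Copy every .jpg in src_dir into dst_dir/<person_id>/.

    The person ID is taken to be the text before the first underscore,
    e.g. '0002' for '0002_c1s1_000451_03.jpg' (assumed naming scheme).
    Returns the number of files copied.
    """
    src, dst = Path(src_dir), Path(dst_dir)
    copied = 0
    for image in sorted(src.glob("*.jpg")):
        person_id = image.name.split("_")[0]
        target = dst / person_id
        target.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(image, target / image.name)
        copied += 1
    return copied
```

    For example, group_by_identity('Market/bounding_box_train', 'Market/pytorch/train_all') would fill train_all/ with folders such as 0002/ and 0007/.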
    

    Train and test the model. First set the --test_dir path in train.py and test.py, here /home/hylink/eclipse-workspace/reID/Market/pytorch:

    python train.py
    python test.py
    python demo.py --query_index 777
    

    Results:
    [image]

    Modify test.py (the original extracts features for both gallery and query; the modified version builds only the gallery feature database)

    # -*- coding: utf-8 -*-
    
    from __future__ import print_function, division
    
    import argparse
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.optim import lr_scheduler
    from torch.autograd import Variable
    import numpy as np
    import torchvision
    from torchvision import datasets, models, transforms
    import time
    import os
    import scipy.io
    from model import ft_net, ft_net_dense, PCB, PCB_test
    
    ######################################################################
    # Options
    # --------
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0  0,1,2  0,2')
    parser.add_argument('--which_epoch',default='last', type=str, help='0,1,2,3...or last')
    parser.add_argument('--test_dir',default='/home/hylink/eclipse-workspace/reID/Market/pytorch',type=str, help='./test_data')
    parser.add_argument('--name', default='ft_ResNet50', type=str, help='save model path')
    parser.add_argument('--batchsize', default=32, type=int, help='batchsize')
    parser.add_argument('--use_dense', action='store_true', help='use densenet121' )
    parser.add_argument('--PCB', action='store_true', help='use PCB' )
    parser.add_argument('--multi', action='store_true', help='use multiple query' )
    
    opt = parser.parse_args()
    
    str_ids = opt.gpu_ids.split(',')
    #which_epoch = opt.which_epoch
    name = opt.name
    test_dir = opt.test_dir
    
    gpu_ids = []
    for str_id in str_ids:
        id = int(str_id)
        if id >=0:
            gpu_ids.append(id)
    
    # set gpu ids
    if len(gpu_ids)>0:
        torch.cuda.set_device(gpu_ids[0])
    
    ######################################################################
    # Load Data
    # ---------
    #
    # We will use torchvision and torch.utils.data packages for loading the
    # data.
    #
    data_transforms = transforms.Compose([
            transforms.Resize((288,144), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ############### Ten Crop        
            #transforms.TenCrop(224),
            #transforms.Lambda(lambda crops: torch.stack(
             #   [transforms.ToTensor()(crop) 
              #      for crop in crops]
               # )),
            #transforms.Lambda(lambda crops: torch.stack(
             #   [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(crop)
              #       for crop in crops]
              # ))
    ])
    
    if opt.PCB:
        data_transforms = transforms.Compose([
            transforms.Resize((384,192), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 
        ])
    
    
    data_dir = test_dir
    
    if opt.multi:
        image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query','multi-query']}
        dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                                 shuffle=False, num_workers=16) for x in ['gallery','query','multi-query']}
    else:
        image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery']}
        dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                                 shuffle=False, num_workers=16) for x in ['gallery']}
    #class_names = image_datasets['query'].classes
    use_gpu = torch.cuda.is_available()
    
    ######################################################################
    # Load model
    #---------------------------
    def load_network(network):
        save_path = os.path.join('./model',name,'net_%s.pth'%opt.which_epoch)
        network.load_state_dict(torch.load(save_path))
        return network
    
    
    ######################################################################
    # Extract feature
    # ----------------------
    #
    # Extract feature from  a trained model.
    #
    def fliplr(img):
        '''flip horizontal'''
        inv_idx = torch.arange(img.size(3)-1,-1,-1).long()  # N x C x H x W
        img_flip = img.index_select(3,inv_idx)
        return img_flip
    
    def extract_feature(model,dataloaders):
        features = torch.FloatTensor()
        count = 0
        for data in dataloaders:
            img, label = data
            n, c, h, w = img.size()
            count += n
            print(count)
            if opt.use_dense:
                ff = torch.FloatTensor(n,1024).zero_()
            else:
                ff = torch.FloatTensor(n,2048).zero_()
            if opt.PCB:
                ff = torch.FloatTensor(n,2048,6).zero_() # we have six parts
            for i in range(2):
                if(i==1):
                    img = fliplr(img)
                input_img = Variable(img.cuda()) if use_gpu else Variable(img)  # stay on CPU when no GPU is available
                outputs = model(input_img) 
                f = outputs.data.cpu()
                ff = ff+f
            # norm feature
            if opt.PCB:
                # feature size (n,2048,6)
                # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature.
                # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6).
                fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(6) 
                ff = ff.div(fnorm.expand_as(ff))
                ff = ff.view(ff.size(0), -1)
            else:
                fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                ff = ff.div(fnorm.expand_as(ff))
    
            features = torch.cat((features,ff), 0)
        return features
    
    def get_id(img_path):
        camera_id = []
        labels = []
        for path, v in img_path:
            #filename = path.split('/')[-1]
            filename = os.path.basename(path)
            label = filename[0:4]
            camera = filename.split('c')[1]
            if label[0:2]=='-1':
                labels.append(-1)
            else:
                labels.append(int(label))
            camera_id.append(int(camera[0]))
        return camera_id, labels
    
    gallery_path = image_datasets['gallery'].imgs
    #query_path = image_datasets['query'].imgs
    
    gallery_cam,gallery_label = get_id(gallery_path)
    #query_cam,query_label = get_id(query_path)
    
    if opt.multi:
        mquery_path = image_datasets['multi-query'].imgs
        mquery_cam,mquery_label = get_id(mquery_path)
    
    ######################################################################
    # Load Collected data Trained model
    print('-------test-----------')
    if opt.use_dense:
        model_structure = ft_net_dense(751)
    else:
        model_structure = ft_net(751)
    
    if opt.PCB:
        model_structure = PCB(751)
    
    model = load_network(model_structure)
    
    # Remove the final fc layer and classifier layer
    if not opt.PCB:
        model.model.fc = nn.Sequential()
        model.classifier = nn.Sequential()
    else:
        model = PCB_test(model)
    
    # Change to test mode
    model = model.eval()
    if use_gpu:
        model = model.cuda()
    
    # Extract feature
    gallery_feature = extract_feature(model,dataloaders['gallery'])
    #query_feature = extract_feature(model,dataloaders['query'])
    if opt.multi:
        mquery_feature = extract_feature(model,dataloaders['multi-query'])
        
    # Save to Matlab for check
    #result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam,'query_f':query_feature.numpy(),'query_label':query_label,'query_cam':query_cam}
    result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam}
    
    scipy.io.savemat('pytorch_result.mat',result)
    if opt.multi:
        result = {'mquery_f':mquery_feature.numpy(),'mquery_label':mquery_label,'mquery_cam':mquery_cam}
        scipy.io.savemat('multi_query.mat',result)
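
    The normalization step in extract_feature is what lets the saved gallery features be compared later with a plain dot product. The sketch below re-creates that step with NumPy on random data, as a sanity check rather than a replacement for the PyTorch code; the helper names normalize_global and normalize_pcb and the small feature sizes are illustrative assumptions.

```python
import numpy as np


def normalize_global(ff):
    """L2-normalize each row of an (n, d) feature matrix."""
    fnorm = np.linalg.norm(ff, axis=1, keepdims=True)
    return ff / fnorm


def normalize_pcb(ff, parts=6):
    """Normalize an (n, d, parts) PCB feature the way extract_feature does.

    Each part is scaled to norm 1/sqrt(parts), so the flattened
    (n, d * parts) vector has unit length overall.
    """
    fnorm = np.linalg.norm(ff, axis=1, keepdims=True) * np.sqrt(parts)
    ff = ff / fnorm
    return ff.reshape(ff.shape[0], -1)


rng = np.random.default_rng(0)
global_feat = normalize_global(rng.normal(size=(4, 8)))
pcb_feat = normalize_pcb(rng.normal(size=(4, 8, 6)))

# Every row has unit length, so a dot product between two rows is a cosine score.
print(np.linalg.norm(global_feat, axis=1))
print(np.linalg.norm(pcb_feat, axis=1))
```

    Both printed vectors should contain values equal to 1 up to floating-point error, which is why the sqrt(6) factor appears in the PCB branch.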
    
    
    

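    Modify demo.py (extract features for the images under the query path, compare them against the gallery feature database, and display the results)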

    # -*- coding: utf-8 -*-
    
    from __future__ import print_function, division
    
    import argparse
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.optim import lr_scheduler
    from torch.autograd import Variable
    import numpy as np
    import torchvision
    from torchvision import datasets, models, transforms
    import time
    import os
    import scipy.io
    import matplotlib.pyplot as plt
    from model import ft_net, ft_net_dense, PCB, PCB_test
    
    ######################################################################
    # Options
    # --------
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0  0,1,2  0,2')
    parser.add_argument('--which_epoch',default='last', type=str, help='0,1,2,3...or last')
    parser.add_argument('--test_dir',default='/home/hylink/eclipse-workspace/reID/Market/pytorch',type=str, help='./test_data')
    parser.add_argument('--name', default='ft_ResNet50', type=str, help='save model path')
    parser.add_argument('--batchsize', default=32, type=int, help='batchsize')
    parser.add_argument('--use_dense', action='store_true', help='use densenet121' )
    parser.add_argument('--PCB', action='store_true', help='use PCB' )
    parser.add_argument('--multi', action='store_true', help='use multiple query' )
    parser.add_argument('--query_index', default=3, type=int, help='test_image_index')
    
    opt = parser.parse_args()
    
    str_ids = opt.gpu_ids.split(',')
    #which_epoch = opt.which_epoch
    name = opt.name
    test_dir = opt.test_dir
    
    gpu_ids = []
    for str_id in str_ids:
        id = int(str_id)
        if id >=0:
            gpu_ids.append(id)
    
    # set gpu ids
    if len(gpu_ids)>0:
        torch.cuda.set_device(gpu_ids[0])
    
    ######################################################################
    # Load Data
    # ---------
    #
    # We will use torchvision and torch.utils.data packages for loading the
    # data.
    #
    data_transforms = transforms.Compose([
            transforms.Resize((288,144), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ############### Ten Crop        
            #transforms.TenCrop(224),
            #transforms.Lambda(lambda crops: torch.stack(
             #   [transforms.ToTensor()(crop) 
              #      for crop in crops]
               # )),
            #transforms.Lambda(lambda crops: torch.stack(
             #   [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(crop)
              #       for crop in crops]
              # ))
    ])
    
    if opt.PCB:
        data_transforms = transforms.Compose([
            transforms.Resize((384,192), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 
        ])
    
    
    data_dir = test_dir
    
    if opt.multi:
        image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query','multi-query']}
        dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                                 shuffle=False, num_workers=16) for x in ['gallery','query','multi-query']}
    else:
        image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query']}
        dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                                 shuffle=False, num_workers=16) for x in ['gallery','query']}
    class_names = image_datasets['query'].classes
    use_gpu = torch.cuda.is_available()
    
    ######################################################################
    # Load model
    #---------------------------
    def load_network(network):
        save_path = os.path.join('./model',name,'net_%s.pth'%opt.which_epoch)
        network.load_state_dict(torch.load(save_path))
        return network
    
    
    ######################################################################
    # Extract feature
    # ----------------------
    #
    # Extract feature from  a trained model.
    #
    def fliplr(img):
        '''flip horizontal'''
        inv_idx = torch.arange(img.size(3)-1,-1,-1).long()  # N x C x H x W
        img_flip = img.index_select(3,inv_idx)
        return img_flip
    
    def extract_feature(model,dataloaders):
        features = torch.FloatTensor()
        count = 0
        for data in dataloaders:
            img, label = data
            n, c, h, w = img.size()
            count += n
            print(count)
            if opt.use_dense:
                ff = torch.FloatTensor(n,1024).zero_()
            else:
                ff = torch.FloatTensor(n,2048).zero_()
            if opt.PCB:
                ff = torch.FloatTensor(n,2048,6).zero_() # we have six parts
            for i in range(2):
                if(i==1):
                    img = fliplr(img)
                input_img = Variable(img.cuda()) if use_gpu else Variable(img)  # stay on CPU when no GPU is available
                outputs = model(input_img) 
                f = outputs.data.cpu()
                ff = ff+f
            # norm feature
            if opt.PCB:
                # feature size (n,2048,6)
                # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature.
                # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6).
                fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(6) 
                ff = ff.div(fnorm.expand_as(ff))
                ff = ff.view(ff.size(0), -1)
            else:
                fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                ff = ff.div(fnorm.expand_as(ff))
    
            features = torch.cat((features,ff), 0)
        return features
    
    def get_id(img_path):
        camera_id = []
        labels = []
        for path, v in img_path:
            #filename = path.split('/')[-1]
            filename = os.path.basename(path)
            label = filename[0:4]
            camera = filename.split('c')[1]
            if label[0:2]=='-1':
                labels.append(-1)
            else:
                labels.append(int(label))
            camera_id.append(int(camera[0]))
        return camera_id, labels
    
    query_path = image_datasets['query'].imgs
    
    query_cam,query_label = get_id(query_path)
    
    if opt.multi:
        mquery_path = image_datasets['multi-query'].imgs
        mquery_cam,mquery_label = get_id(mquery_path)
    
    ######################################################################
    # Load Collected data Trained model
    print('-------test-----------')
    if opt.use_dense:
        model_structure = ft_net_dense(751)
    else:
        model_structure = ft_net(751)
    
    if opt.PCB:
        model_structure = PCB(751)
    
    model = load_network(model_structure)
    
    # Remove the final fc layer and classifier layer
    if not opt.PCB:
        model.model.fc = nn.Sequential()
        model.classifier = nn.Sequential()
    else:
        model = PCB_test(model)
    
    # Change to test mode
    model = model.eval()
    if use_gpu:
        model = model.cuda()
    
    # Extract feature
    
    query_feature = extract_feature(model,dataloaders['query'])
    
    ######################################################################
    ######################################################################
    
    
    def imshow(path, title=None):
        """Imshow for Tensor."""
        im = plt.imread(path)
        plt.imshow(im)
        if title is not None:
            plt.title(title)
        plt.pause(0.001)  # pause a bit so that plots are updated
    
    ######################################################################
    result = scipy.io.loadmat('pytorch_result.mat')
    gallery_feature = torch.FloatTensor(result['gallery_f'])
    gallery_cam = result['gallery_cam'][0]
    gallery_label = result['gallery_label'][0]
    
    # Move features to the GPU only when one is available
    if use_gpu:
        query_feature = query_feature.cuda()
        gallery_feature = gallery_feature.cuda()
    
    #######################################################################
    # sort the images
    def sort_img(qf, ql, qc, gf, gl, gc):
        query = qf.view(-1,1)
        # print(query.shape)
        score = torch.mm(gf,query)
        score = score.squeeze(1).cpu()
        score = score.numpy()
        # predict index
        index = np.argsort(score)  #from small to large
        index = index[::-1]
        # index = index[0:2000]
        # good index
        query_index = np.argwhere(gl==ql)
        #same camera
        camera_index = np.argwhere(gc==qc)
    
        junk_index1 = np.argwhere(gl==-1)
        junk_index2 = np.intersect1d(query_index, camera_index)
        junk_index = np.append(junk_index2, junk_index1) 
    
        mask = np.isin(index, junk_index, invert=True)  # np.in1d is deprecated in NumPy 2.0
        index = index[mask]
        return index
    
    i = opt.query_index
    index = sort_img(query_feature[i],query_label[i],query_cam[i],gallery_feature,gallery_label,gallery_cam)
    
    ########################################################################
    # Visualize the rank result
    
    query_path, _ = image_datasets['query'].imgs[i]
    query_label = query_label[i]
    print(query_path)
    print('Top 10 images are as follows:')
    top_k = min(10, len(index))  # a small custom gallery may hold fewer than 10 images
    try: # Visualize Ranking Result 
        # Graphical User Interface is needed
        fig = plt.figure(figsize=(16,4))
        ax = plt.subplot(1,11,1)
        ax.axis('off')
        imshow(query_path,'query')
        for i in range(top_k):
            ax = plt.subplot(1,11,i+2)
            ax.axis('off')
            img_path, _ = image_datasets['gallery'].imgs[index[i]]
            label = gallery_label[index[i]]
            imshow(img_path)
            if label == query_label:
                ax.set_title('%d'%(i+1), color='green')
            else:
                ax.set_title('%d'%(i+1), color='red')
            print(img_path)
        # Save inside the try block so fig is never used when plotting failed
        fig.savefig("show.png")
    except RuntimeError:
        # No display available: fall back to printing the ranked gallery paths
        for i in range(top_k):
            img_path, _ = image_datasets['gallery'].imgs[index[i]]
            print(img_path)
        print('If you want to see the visualization of the ranking result, graphical user interface is needed.')
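
    The ranking logic in sort_img can be tried on toy data without a model or a GPU. The following is a NumPy re-implementation under the assumption that all features are already L2-normalized; rank_gallery and the four-image gallery are made up for illustration. As in sort_img, gallery images with the query's ID and camera, and images labelled -1, are treated as junk and dropped from the ranking.

```python
import numpy as np


def rank_gallery(qf, ql, qc, gf, gl, gc):
    """Return gallery indices sorted by descending similarity to one query.

    qf: (d,) query feature; gf: (n, d) gallery features (rows L2-normalized).
    ql, qc: query label and camera; gl, gc: (n,) gallery labels and cameras.
    """
    score = gf @ qf                      # cosine similarity for unit vectors
    index = np.argsort(score)[::-1]      # most similar first
    same_id = np.flatnonzero(gl == ql)
    same_cam = np.flatnonzero(gc == qc)
    junk = np.concatenate([np.intersect1d(same_id, same_cam),
                           np.flatnonzero(gl == -1)])
    return index[~np.isin(index, junk)]


# Toy gallery: four unit-length 2-D features with labels and camera IDs.
gf = np.array([[1.0, 0.0], [0.6, 0.8], [0.0, 1.0], [0.8, 0.6]])
gl = np.array([1, 1, 2, -1])
gc = np.array([1, 2, 3, 2])
ranking = rank_gallery(np.array([1.0, 0.0]), 1, 1, gf, gl, gc)
print(ranking)
```

    Here gallery image 0 is removed because it shares both ID and camera with the query, and image 3 is removed because its label is -1, so only images 1 and 2 remain in the ranking.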
        
    

    Place the custom gallery images in pytorch/gallery/
    [image]
    Place the custom query images in pytorch/query/
    [image]
    Results
    [image]
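
    Custom images still pass through get_id, which reads the person ID from the first four characters of the filename and the camera ID from the digit after the first 'c'. The sketch below mirrors those rules so that filenames can be checked before running test.py or demo.py; parse_market_name is a hypothetical helper and the example paths are illustrative. Since both scripts load images with ImageFolder, the files are also assumed to sit inside a subfolder of gallery/ or query/.

```python
import os


def parse_market_name(path):
    """Return (person_id, camera_id) using the same rules as get_id."""
    filename = os.path.basename(path)
    label = filename[0:4]
    camera = filename.split('c')[1]
    person_id = -1 if label[0:2] == '-1' else int(label)
    return person_id, int(camera[0])


print(parse_market_name('gallery/0002/0002_c1s1_000451_03.jpg'))
print(parse_market_name('gallery/-1/-1_c3s1_000401_02.jpg'))
```

    A filename that does not start with a four-digit ID, or that has no 'c' followed by a digit, raises an error here just as it would inside get_id.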

  • Original article: https://www.cnblogs.com/gmhappy/p/11864018.html