
    Open Set Domain Adaptation by Backpropagation (OSBP): Reproducing the Paper on Digit Datasets

    1. Preparing the datasets

    MNIST: 28×28 grayscale images, 70,000 in total, 10 digit classes

    USPS: 16×16 grayscale images, 9,298 in total (7,438 train + 1,860 test in the pickle used here), 10 digit classes

    SVHN: 32×32 color images, 73,257 in the training split, 10 digit classes

    torchvision.datasets did not ship a USPS dataset at the time, so a custom Dataset class is used instead (see the note after the class for the built-in alternative in newer torchvision releases):

    """Dataset setting and data loader for USPS.
    Modified from
    https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py
    """
    
    import gzip
    import os
    import pickle
    import urllib.request
    from PIL import Image
    
    import numpy as np
    import torch
    import torch.utils.data as data
    from torch.utils.data.sampler import WeightedRandomSampler
    from torchvision import datasets, transforms
    
    
    class USPS(data.Dataset):
        """USPS Dataset.
        Args:
            root (string): Root directory of dataset where dataset file exist.
            train (bool, optional): If True, resample from dataset randomly.
            download (bool, optional): If true, downloads the dataset
                from the internet and puts it in root directory.
                If dataset is already downloaded, it is not downloaded again.
            transform (callable, optional): A function/transform that takes in
                an PIL image and returns a transformed version.
                E.g, ``transforms.RandomCrop``
        """
    
        url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl"
    
        def __init__(self, root, train=True, transform=None, download=False):
            """Init USPS dataset."""
            # init params
            self.root = os.path.expanduser(root)
            self.filename = "usps_28x28.pkl"
            self.train = train
            # Num of Train = 7438, Num of Test = 1860
            self.transform = transform
            self.dataset_size = None
    
            # download dataset.
            if download:
                self.download()
            if not self._check_exists():
                raise RuntimeError("Dataset not found." +
                                   " You can use download=True to download it")
    
            self.train_data, self.train_labels = self.load_samples()
            if self.train:
                total_num_samples = self.train_labels.shape[0]
                indices = np.arange(total_num_samples)
                self.train_data = self.train_data[indices[0:self.dataset_size], ::]
                self.train_labels = self.train_labels[indices[0:self.dataset_size]]
            self.train_data *= 255.0
            self.train_data = np.squeeze(self.train_data).astype(np.uint8)
    
        def __getitem__(self, index):
            """Get images and target for data loader.
            Args:
                index (int): Index
            Returns:
                tuple: (image, target) where target is index of the target class.
            """
            img, label = self.train_data[index], self.train_labels[index]
            img = Image.fromarray(img, mode='L')
            img = img.copy()
            if self.transform is not None:
                img = self.transform(img)
            return img, label.astype("int64")
    
        def __len__(self):
            """Return size of dataset."""
            return len(self.train_data)
    
        def _check_exists(self):
            """Check if dataset is download and in right place."""
            return os.path.exists(os.path.join(self.root, self.filename))
    
        def download(self):
            """Download dataset."""
            filename = os.path.join(self.root, self.filename)
            dirname = os.path.dirname(filename)
            if not os.path.isdir(dirname):
                os.makedirs(dirname)
            if os.path.isfile(filename):
                return
            print("Download %s to %s" % (self.url, os.path.abspath(filename)))
            urllib.request.urlretrieve(self.url, filename)
            print("[DONE]")
            return
    
        def load_samples(self):
            """Load sample images from dataset."""
            filename = os.path.join(self.root, self.filename)
            f = gzip.open(filename, "rb")
            data_set = pickle.load(f, encoding="bytes")
            f.close()
            if self.train:
                images = data_set[0][0]
                labels = data_set[0][1]
                self.dataset_size = labels.shape[0]
            else:
                images = data_set[1][0]
                labels = data_set[1][1]
                self.dataset_size = labels.shape[0]
            return images, labels
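
    Note: torchvision 0.4.0 and later do ship a built-in USPS dataset, so the custom class above is only needed on older installs. A minimal sketch of the built-in loader, assuming torchvision >= 0.4 (it keeps the original 16×16 resolution, while the pickle used above is pre-resized to 28×28):

    from torchvision import transforms
    from torchvision.datasets import USPS as TVUSPS  # aliased to avoid clashing with the class above

    # Built-in USPS: 16x16 grayscale digits, 7291 train / 2007 test samples.
    usps_train = TVUSPS('datasets/USPS', train=True, download=True,
                        transform=transforms.ToTensor())
    print(len(usps_train))   # 7291
    img, label = usps_train[0]
    print(img.shape, label)  # torch.Size([1, 16, 16]) and an int label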
    
    
    from torchvision import transforms
    from torchvision.datasets import MNIST
    from torchvision.datasets import SVHN
    
    from .mnist import *
    from .svhn import *
    from .usps import *
    
    def get_dataset(task):
        if task == 's2m':
            #Note: the SVHN root path here is written differently from the MNIST/USPS paths below
            train_dataset = SVHN('datasets/SVHN', split='train', download=False,#split='train' selects the training split of SVHN
                    transform=transforms.Compose([
                        transforms.Resize(32),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                    ]))
            
            test_dataset = MNIST('datasets', train=True, download=False,
                    transform=transforms.Compose([
                        transforms.Resize(32),
                        transforms.Lambda(lambda x: x.convert("RGB")),#SVHN images are 3-channel color, so convert MNIST to RGB for consistency
                        transforms.ToTensor(),
                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))#first tuple: per-channel (R, G, B) means; second tuple: per-channel standard deviations
                    ]))
        elif task == 'u2m':
            train_dataset = USPS('datasets', train=True, download=False,
                    transform=transforms.Compose([
                        transforms.RandomCrop(28, padding=4),
                        transforms.RandomRotation(10),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,), (0.5,))#single channel here, unlike the 3-channel case above
                    ]))
            
            test_dataset = MNIST('datasets', train=True, download=False,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,), (0.5,))
                    ]))
        else:
            train_dataset = MNIST('datasets', train=True, download=False,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,), (0.5,))
                    ]))
    
            test_dataset = USPS('datasets', train=True, download=False,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,), (0.5,))
                    ]))
        
        return relabel_dataset(train_dataset, test_dataset, task)
    
    def relabel_dataset(train_dataset, test_dataset, task):
        image_path = []
        image_label = []
        if task == 's2m':
            for i in range(len(train_dataset.data)):
                if int(train_dataset.labels[i]) < 5:#keep only source samples with label < 5; the source domain has no unknown classes, so digits 5-9 are excluded
                    image_path.append(train_dataset.data[i])
                    image_label.append(train_dataset.labels[i])
            train_dataset.data = image_path
            train_dataset.labels = image_label
        else:
            for i in range(len(train_dataset.train_data)):
                if int(train_dataset.train_labels[i]) < 5:
                    image_path.append(train_dataset.train_data[i])
                    image_label.append(train_dataset.train_labels[i])
            train_dataset.train_data = image_path
            train_dataset.train_labels = image_label
    
        for i in range(len(test_dataset.train_data)):#all MNIST training samples are used as the target; labels 5-9 are collapsed into 5, the target-domain "unknown" class
            if int(test_dataset.train_labels[i]) >= 5:
                test_dataset.train_labels[i] = 5
            
        return train_dataset, test_dataset
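
    A short usage sketch of the resulting open-set split (the comments just restate what the code above produces):

    train_dataset, test_dataset = get_dataset('s2m')
    # Source (SVHN): only digits 0-4 remain, i.e. 5 known classes.
    # Target (MNIST): all samples are kept; labels 5-9 are remapped to 5 ("unknown"),
    # so target labels take values in {0, 1, 2, 3, 4, 5}.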
    

    2. Model architecture

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    class Conv_Block(nn.Module):#conv block: a convolution + LeakyReLU + BatchNorm
        def __init__(self, in_channels, out_channels, kernel_size, stride=1):
            super(Conv_Block, self).__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride)
            self.relu = torch.nn.LeakyReLU()
            self.bn = nn.BatchNorm2d(out_channels)
        
        def forward(self, x):
            x = self.conv(x)
            x = self.relu(x)
            x = self.bn(x)
            return x
    
    class Dense_Block(nn.Module):#fully connected block: a linear layer + LeakyReLU + BatchNorm
        def __init__(self, in_features, out_features):
            super(Dense_Block, self).__init__()
            self.fc = nn.Linear(in_features, out_features)
            self.relu = torch.nn.LeakyReLU()
            self.bn = nn.BatchNorm1d(out_features)
        
        def forward(self, x):
            x = self.fc(x)
            x = self.relu(x)
            x = self.bn(x)
            return x
    
    
    
    # https://zhuanlan.zhihu.com/p/263827804
    #
    # Training raised the error:
    #
    # RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method.
    #
    # The original code targets an older PyTorch; versions above 1.3 raise this error because they require a static forward method, so the autograd function has to be rewritten in the new style.
    
    
    # class GradReverse(torch.autograd.Function):#gradient reversal layer; this old-style autograd function breaks on newer PyTorch
    #     # def __init__(self, lambd):
    #     #     self.lambd = lambd
    #     #
    #     # def forward(self, x):
    #     #     return x.view_as(x)
    #     #
    #     # def backward(self, grad_output):
    #     #     return (grad_output * -self.lambd)
    #     # when overriding a parent method, it is best to keep the default parameters, otherwise a warning is raised
    #     @staticmethod
    #     def forward(ctx, x):
    #         result = x.view_as(x)
    #         ctx.save_for_backward(result)
    #         return result
    #
    #     @staticmethod
    #     def backward(ctx, grad_output):
    #         return (grad_output * -1.0)
    #---------------------------------------------------------------------------------------------
    class GradReverse(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, lambda_):
            ctx.save_for_backward(lambda_)
            return x.view_as(x)

        @staticmethod
        def backward(ctx, grad_output):
            lambda_, = ctx.saved_tensors  # ctx.saved_variables is deprecated; use saved_tensors on current PyTorch
            grad_input = grad_output.clone()
            return -lambda_ * grad_input, None  # the None is the gradient w.r.t. lambda_, which is not needed
    
    
    # class GRL(nn.Module):
    #     def __init__(self, lambda_=0.):
    #         super(GRL, self).__init__()
    #         self.lambda_ = torch.tensor(lambda_)
    #
    #     def set_lambda(self, lambda_):
    #         self.lambda_ = torch.tensor(lambda_)
    #
    #     def forward(self, x):
    #         return GradReverse.apply(x, self.lambda_)
        #---------------------------------------------------------------------------------------------
    
    
    def grad_reverse(x, lambd=1.0):
        # return GradReverse(lambd)(x)
        lam = torch.tensor(lambd)#lambd is a Python float; wrap it in a tensor so it can be saved on ctx via save_for_backward
        return GradReverse.apply(x,lam)
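
    # A quick sanity check (a minimal sketch) that the layer is the identity in the
    # forward pass and scales gradients by -lambda in the backward pass:
    if __name__ == '__main__':
        x = torch.randn(4, 100, requires_grad=True)
        y = grad_reverse(x, lambd=0.5)
        assert torch.equal(x, y)   # forward pass leaves the values unchanged
        y.sum().backward()
        print(x.grad.unique())     # tensor([-0.5000]): the all-ones gradient, reversed and scaled by lambda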
    
    class Generator_s2m(nn.Module):#feature generator
        def __init__(self):
            super(Generator_s2m, self).__init__()
            self.conv1 = Conv_Block(3, 64, kernel_size=5)    
            self.conv2 = Conv_Block(64, 64, kernel_size=5)
            self.conv3 = Conv_Block(64, 128, kernel_size=3, stride=2)
            self.conv4 = Conv_Block(128, 128, kernel_size=3, stride=2)
            self.fc1 = Dense_Block(3200, 100)
            self.fc2 = Dense_Block(100, 100)
            
        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.conv3(x)
            x = self.conv4(x)
            x = x.view(x.size(0), -1)
            x = self.fc1(x)
            x = self.fc2(x)
            return x
    
    class Classifier_s2m(nn.Module):#classifier
        def __init__(self, n_output):
            super(Classifier_s2m, self).__init__()
            self.fc = nn.Linear(100, n_output)
    
        def forward(self, x):
            x = self.fc(x)
            return x
    
    class Generator_u2m(nn.Module):
        def __init__(self):
            super(Generator_u2m, self).__init__()
            self.conv1 = Conv_Block(1, 20, kernel_size=5)
            self.pool1 = nn.MaxPool2d(2, stride=2)
            self.conv2 = Conv_Block(20, 50, kernel_size=5)
            self.pool2 = nn.MaxPool2d(2, stride=2)
            self.drop = nn.Dropout()
            self.fc = Dense_Block(800, 500)
            
        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            x = self.conv2(x)
            x = self.pool2(x)
            x = x.view(x.size(0), -1)
            x = self.drop(x)
            x = self.fc(x)
            return x
    
    class Classifier_u2m(nn.Module):
        def __init__(self, n_output):
            super(Classifier_u2m, self).__init__()
            self.fc = nn.Linear(500, n_output)
    
        def forward(self, x):
            x = self.fc(x)
            return x
    
    class Net(nn.Module):
        def __init__(self, task='s2m'):
            super(Net, self).__init__()
            if task == 's2m':
                self.generator = Generator_s2m()
                self.classifier = Classifier_s2m(6)#6 outputs: 5 known digit classes + 1 "unknown" class
            elif task == 'u2m' or task == 'm2u':
                self.generator = Generator_u2m()
                self.classifier = Classifier_u2m(6)
                    
            for m in self.modules():#initialize the model parameters
                if isinstance(m, nn.Conv2d):#conv layers feed (Leaky)ReLU activations, so use Kaiming normal initialization
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm1d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
    
        def forward(self, x, constant = 1, adaption = False):
            x = self.generator(x)
            if adaption:#insert the gradient reversal layer for the adversarial (target) pass
                x = grad_reverse(x, constant)
            x = self.classifier(x)
            return x
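
    For clarity, a short usage sketch of the two forward modes (shapes follow the u2m generator, which expects 28×28 grayscale input):

    net = Net(task='u2m')
    x = torch.randn(8, 1, 28, 28)
    logits = net(x)                                   # plain pass: generator -> classifier
    logits_adv = net(x, constant=0.3, adaption=True)  # same pass with the GRL inserted between them
    print(logits.shape)                               # torch.Size([8, 6])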
    

    3. Training (SVHN → MNIST)

    The network is trained on all SVHN training samples whose labels are in the range 0-4; on the MNIST side, all training samples are used.

    from __future__ import print_function
    import argparse
    import os
    import numpy as np
    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim#aliased so the name is not shadowed by the optimizer instance created below
    
    from datasets.get_dataset import get_dataset
    import models
    import utils
    
    NUM_CLASSES = 6
    
    # Training settings
    parser = argparse.ArgumentParser(description='Openset-DA SVHN -> MNIST Example')
    parser.add_argument('--task', choices=['s2m', 'u2m', 'm2u'], default='s2m',
                        help='type of task')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--epochs', type=int, default=35, metavar='N',
                        help='number of epochs to train (default: 35)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--lr-rampdown-epochs', default=201, type=int, metavar='EPOCHS',
                            help='length of learning rate cosine rampdown (>= length of training)')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--grl-rampup-epochs', default=20, type=int, metavar='EPOCHS',
                            help='length of grl rampup')
    parser.add_argument('--weight-decay', '--wd', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)')
    parser.add_argument('--th', type=float, default=0.5, metavar='TH',
                        help='threshold (default: 0.5)')
    parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--gpu', default='0', type=str, metavar='GPU',
                        help='id(s) for CUDA_VISIBLE_DEVICES')
    args = parser.parse_args()
    
    torch.backends.cudnn.benchmark = True#let cuDNN benchmark and choose the fastest convolution algorithms, speeding up training
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu#restrict the visible CUDA devices; the default '0' uses only GPU 0
    
    source_dataset, target_dataset = get_dataset(args.task)
    
    source_loader = torch.utils.data.DataLoader(source_dataset, 
        batch_size=args.batch_size, shuffle=True, num_workers=0)
    
    target_loader = torch.utils.data.DataLoader(target_dataset,
        batch_size=args.batch_size, shuffle=True, num_workers=0)
    
    model = models.Net(task=args.task).cuda()
    
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,#L2 weight-decay coefficient (default 1e-3); this is regularization, not a learning-rate decay
                                    nesterov=True)#use Nesterov momentum
                             
    if args.resume:
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
                .format(args.resume, checkpoint['epoch']))
    
    criterion_bce = nn.BCELoss()
    criterion_cel = nn.CrossEntropyLoss()
    
    best_prec1 = 0
    best_pred_y = []
    best_gt_y = []
    global_step = 0
    total_steps = args.grl_rampup_epochs * len(source_loader)
    
    def train(epoch):
        model.train()
        global global_step
        for batch_idx, (batch_s, batch_t) in enumerate(zip(source_loader, target_loader)):
            adjust_learning_rate(optimizer, epoch, batch_idx, len(source_loader))
            p = global_step / total_steps
            constant = 2. / (1. + np.exp(-10 * p)) - 1#'constant' is the lambda coefficient of the gradient reversal layer below; as proposed in the DANN paper it ramps from 0 to 1 (gamma = 10)
    
            data_s, target_s = batch_s
            data_t, target_t = batch_t
    
            data_s, target_s = data_s.cuda(), target_s.cuda(non_blocking=True)
            data_t, target_t = data_t.cuda(), target_t.cuda(non_blocking=True)
    
            batch_size_s = len(target_s)
            batch_size_t = len(target_t)
    
            optimizer.zero_grad()
    
            output_s = model(data_s)#ordinary forward pass on the source data, used for the supervised classification loss Ls
            output_t = model(data_t, constant = constant, adaption = True)#forward pass through the gradient reversal layer on the target data, used for the adversarial loss Ladv
    
            loss_cel = criterion_cel(output_s, target_s)
            # standard cross-entropy loss: classify the source samples x_s correctly
    
            output_t_prob_unk = F.softmax(output_t, dim=1)[:,-1]#softmax turns the classifier's logits into probabilities; the last entry is the probability that the sample belongs to the unknown class
    
            loss_adv = criterion_bce(output_t_prob_unk, torch.tensor([args.th]*batch_size_t).cuda())
            # binary cross-entropy loss L_adv(x_t): push p(unknown | x_t) toward the threshold t = 0.5, which trains the classifier to build a boundary for target-private samples
            loss =  loss_cel + loss_adv
            
            loss.backward()
            optimizer.step()
    
            global_step += 1
    
            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tConstant: {:.4f}'.format(
                    epoch, batch_idx * args.batch_size, len(source_loader.dataset),
                    100. * batch_idx / len(source_loader), loss.item(), constant))
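
    # Note: 'constant' above follows the GRL schedule from the DANN paper,
    # lambda(p) = 2 / (1 + exp(-10p)) - 1, with p the training progress in [0, 1].
    # A few sample values (a quick sketch):
    #   p = 0.0  -> 0.0
    #   p = 0.1  -> 0.4621
    #   p = 0.25 -> 0.8483
    #   p = 0.5  -> 0.9866
    #   p = 1.0  -> 0.9999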
    
    def test(epoch):
        global best_prec1
        model.eval()
        loss = 0
        pred_y = []
        true_y = []
    
        correct = 0
        ema_correct = 0
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(target_loader):
                data, target = data.cuda(), target.cuda(non_blocking=True)
                output = model(data)
    
                loss += criterion_cel(output, target).item() # sum up batch loss
    
                pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
    
                for i in range(len(pred)):
                    pred_y.append(pred[i].item())
                    true_y.append(target[i].item())
    
                correct += pred.eq(target.view_as(pred)).sum().item()
    
        loss /= len(target_loader.dataset)
    
        utils.cal_acc(true_y, pred_y, NUM_CLASSES)
    
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            loss, correct, len(target_loader.dataset),
            100. * correct / len(target_loader.dataset)))
    
        prec1 = 100. * correct / len(target_loader.dataset)
        if epoch % 1 == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            utils.save_checkpoint({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, is_best)
            if is_best:
                global best_gt_y
                global best_pred_y
                best_gt_y = true_y
                best_pred_y = pred_y
    
    def adjust_learning_rate(optimizer, epoch, step_in_epoch, total_steps_in_epoch):#learning-rate schedule applied every batch during training
        lr = args.lr
        epoch = epoch + step_in_epoch / total_steps_in_epoch
    
        lr *= utils.cosine_rampdown(epoch, args.lr_rampdown_epochs)
    
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    
    def main():
        try:
            for epoch in range(1, args.epochs + 1):
                train(epoch)
                test(epoch)
            print ("------Best Result-------")
            utils.cal_acc(best_gt_y, best_pred_y, NUM_CLASSES)
        except KeyboardInterrupt:
            print ("------Best Result-------")
            utils.cal_acc(best_gt_y, best_pred_y, NUM_CLASSES)
    
    
    if __name__ == '__main__':
        main()
    
    import shutil
    
    import numpy as np
    from sklearn.metrics import accuracy_score
    
    import torch
    
    
    def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(filename, 'model_best.pth.tar')
    
    def cal_acc(gt_list, predict_list, num):
        acc_sum = 0
        for n in range(num):
            y = []
            pred_y = []
            for i in range(len(gt_list)):
                gt = gt_list[i]
                predict = predict_list[i]
                if gt == n:
                    y.append(gt)
                    pred_y.append(predict)
            print ('{}: {:4f}'.format(n if n != (num - 1) else 'Unk', accuracy_score(y, pred_y)))
            if n == (num - 1):
                print ('Known Avg Acc: {:4f}'.format(acc_sum / (num - 1)))
            acc_sum += accuracy_score(y, pred_y)
        print ('Avg Acc: {:4f}'.format(acc_sum / num))
        print ('Overall Acc : {:4f}'.format(accuracy_score(gt_list, predict_list)))
    
    def cosine_rampdown(current, rampdown_length):#cosine decay of the learning rate
        """Cosine rampdown from https://arxiv.org/abs/1608.03983"""
        assert 0 <= current <= rampdown_length
        return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))
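
    A few sample values of this schedule with the default --lr-rampdown-epochs=201 (the returned factor multiplies the base learning rate of 0.001):

    for epoch in [0, 50, 100, 150, 201]:
        print(epoch, round(cosine_rampdown(epoch, 201), 4))
    # 0 -> 1.0, 50 -> 0.8549, 100 -> 0.5039, 150 -> 0.1506, 201 -> 0.0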
    

    4. Results

    [Screenshot of the training results]

    Code: https://github.com/redhat12345/Domain-Adaptation-Papers-and-Codes

    Original post: https://www.cnblogs.com/Jason66661010/p/13918238.html