zoukankan      html  css  js  c++  java
  • PyTorch 的分布式

    学习:https://zhuanlan.zhihu.com/p/136372142

    DDP:

    from __future__ import print_function
    import sys
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    import torch.backends.cudnn as cudnn
    
    import torch.nn.parallel
    import torch.distributed as dist
    import torch.utils.data
    import torch.utils.data.distributed
    
    import random
    import os
    import sys
    import argparse
    import numpy as np
    from InceptionResNetV2 import *
    from swin_transformer import *
    from sklearn.mixture import GaussianMixture
    import dataloader_aliproduct as dataloader
    import torchnet
    import time
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import *
    from apex import amp, optimizers
    from apex.multi_tensor_apply import multi_tensor_applier
    import apex
    
    parser = argparse.ArgumentParser(description='PyTorch WebVision Training')
    parser.add_argument('--batch_size', default=64, type=int, help='train batchsize')
    parser.add_argument('--lr', '--learning_rate', default=0.01, type=float, help='initial learning rate')
    parser.add_argument('--alpha', default=0.5, type=float, help='parameter for Beta')
    parser.add_argument('--lambda_u', default=0, type=float, help='weight for unsupervised loss')
    parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold')
    parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
    parser.add_argument('--num_epochs', default=200, type=int)
    parser.add_argument('--id', default='', type=str)
    # type=int so a seed given on the command line behaves like the integer
    # default (random.seed('123') and random.seed(123) give different streams).
    parser.add_argument('--seed', default=123, type=int)
    parser.add_argument('--gpuid', default=0, type=int)
    parser.add_argument('--num_class', default=50030, type=int)
    parser.add_argument('--data_path', default='./dataset/', type=str, help='path to dataset')
    parser.add_argument('--opt-level', default='O1', type=str)
    parser.add_argument("--local_rank", default=4, type=int)

    parser.add_argument('--world-size', default=2, type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--gpu', default=-1, type=int,
                        help='GPU id to use.')
    parser.add_argument('--multiprocessing-distributed', action='store_true',
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs. This is the '
                             'fastest way to use PyTorch for either single node or '
                             'multi node data parallel training')

    args = parser.parse_args()
    random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    args.distributed = args.world_size > 1 or args.multiprocessing_distributed
    ngpus_per_node = torch.cuda.device_count()

    # The device actually used is --local_rank (presumably supplied per process
    # by the launcher, e.g. torch.distributed.launch -- confirm); --gpu is not
    # used for device selection, so report the real one.
    print("Use GPU: {} for training".format(args.local_rank))
    torch.cuda.set_device(args.local_rank)

    # init_method='env://' takes the rendezvous information from environment
    # variables set by the launcher, so --dist-url is not used here.
    dist.init_process_group(backend=args.dist_backend, init_method='env://')

    print("after dist init")
    
    # Training
    def train(epoch,net,net2,optimizer,labeled_trainloader,unlabeled_trainloader):
        """Train `net` for one epoch with label co-guessing, label refinement and MixUp.

        `net2` is switched to eval mode and used (under no_grad) only to co-guess
        soft targets for the unlabeled batches; only `net` receives gradients.

        Args:
            epoch: current epoch, used only in the progress line.
            net: network updated by `optimizer`.
            net2: peer network, kept fixed.
            optimizer: optimizer for `net` (already passed through amp.initialize).
            labeled_trainloader: yields (inputs_x, inputs_x2, labels_x, w_x).
            unlabeled_trainloader: yields (inputs_u, inputs_u2); restarted whenever
                it runs out before the labeled loader does.
        """
        net.train()
        net2.eval()  # fix one network and train the other

        unlabeled_train_iter = iter(unlabeled_trainloader)
        num_iter = (len(labeled_trainloader.dataset) // args.batch_size) + 1
        for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader):
            # next() is the Python 3 iterator protocol (iterator `.next()` is not);
            # only StopIteration means "restart the unlabeled loader" -- any other
            # error must propagate instead of being silently retried.
            try:
                inputs_u, inputs_u2 = next(unlabeled_train_iter)
            except StopIteration:
                unlabeled_train_iter = iter(unlabeled_trainloader)
                inputs_u, inputs_u2 = next(unlabeled_train_iter)
            batch_size = inputs_x.size(0)

            # Transform label to one-hot
            labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1, 1), 1)
            w_x = w_x.view(-1, 1).type(torch.FloatTensor)

            inputs_x, inputs_x2, labels_x, w_x = inputs_x.cuda(), inputs_x2.cuda(), labels_x.cuda(), w_x.cuda()
            inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda()

            with torch.no_grad():
                # label co-guessing of unlabeled samples: average both networks'
                # predictions over both views
                outputs_u11 = net(inputs_u)
                outputs_u12 = net(inputs_u2)
                outputs_u21 = net2(inputs_u)
                outputs_u22 = net2(inputs_u2)

                pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1) + torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4
                ptu = pu ** (1 / args.T)  # temperature sharpening

                targets_u = ptu / ptu.sum(dim=1, keepdim=True)  # normalize
                targets_u = targets_u.detach()

                # label refinement of labeled samples: blend the given one-hot
                # label with net's own prediction, weighted per sample by w_x
                outputs_x = net(inputs_x)
                outputs_x2 = net(inputs_x2)

                px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2
                px = w_x * labels_x + (1 - w_x) * px
                ptx = px ** (1 / args.T)  # temperature sharpening

                targets_x = ptx / ptx.sum(dim=1, keepdim=True)  # normalize
                targets_x = targets_x.detach()

            # mixmatch: mix every sample with a random partner; max(l, 1-l) keeps
            # l >= 0.5 so each mix stays closer to the unpermuted sample
            l = np.random.beta(args.alpha, args.alpha)
            l = max(l, 1 - l)

            all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
            all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)

            idx = torch.randperm(all_inputs.size(0))

            input_a, input_b = all_inputs, all_inputs[idx]
            target_a, target_b = all_targets, all_targets[idx]

            # Only the first 2*batch_size rows (the two labeled views) are trained on.
            mixed_input = l * input_a[:batch_size * 2] + (1 - l) * input_b[:batch_size * 2]
            mixed_target = l * target_a[:batch_size * 2] + (1 - l) * target_b[:batch_size * 2]

            logits = net(mixed_input)

            Lx = -torch.mean(torch.sum(F.log_softmax(logits, dim=1) * mixed_target, dim=1))

            # KL(uniform prior || batch-mean prediction): pushes the average
            # prediction towards the uniform distribution over classes.
            prior = torch.ones(args.num_class) / args.num_class
            prior = prior.cuda()
            pred_mean = torch.softmax(logits, dim=1).mean(0)
            penalty = torch.sum(prior * torch.log(prior / pred_mean))

            loss = Lx + penalty
            # compute gradient (with amp loss scaling) and do SGD step
            optimizer.zero_grad()
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            if (batch_idx + 1) % 100 == 1:
                sys.stdout.write('\r')
                sys.stdout.write('%s | Epoch [%3d/%3d] Iter[%4d/%4d]\t Labeled loss: %.2f'
                        % (args.id, epoch, args.num_epochs, batch_idx + 1, num_iter, Lx.item()))
                sys.stdout.flush()
    
    def warmup(epoch,net,optimizer,dataloader):
        """Warm up `net` for one epoch with plain cross-entropy on the given labels.

        Uses the module-level `CEloss` criterion and amp loss scaling, and prints a
        progress line (with elapsed seconds) every 100 batches.

        Args:
            epoch: current epoch, used only in the progress line.
            net: network to train.
            optimizer: optimizer for `net` (already passed through amp.initialize).
            dataloader: yields (inputs, labels, path); `path` is ignored.
        """
        start1 = time.time()
        net.train()
        num_iter = (len(dataloader.dataset) // dataloader.batch_size) + 1
        for batch_idx, (inputs, labels, _path) in enumerate(dataloader):
            inputs, labels = inputs.cuda(), labels.cuda()
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = CEloss(outputs, labels)
            # A confidence penalty (conf_penalty(outputs)) could be added to `loss` here.

            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            if (batch_idx + 1) % 100 == 1:
                end1 = time.time()
                sys.stdout.write('\r')
                sys.stdout.write('%s | Epoch [%3d/%3d] Iter[%4d/%4d]\t CE-loss: %.4f\t %.4f\n'
                        % (args.id, epoch, args.num_epochs, batch_idx + 1, num_iter, loss.item(), end1 - start1))
                sys.stdout.flush()
            
            
    def test(epoch,net1,net2,test_loader):
        """Evaluate the ensemble of `net1` and `net2` (summed logits) on `test_loader`.

        Resets and fills the module-level `acc_meter` (built with topk=[1, 5] and
        accuracy=True) and returns its value, i.e. the ensemble's
        [top-1, top-5] accuracy. `epoch` is unused and kept for call compatibility.
        """
        acc_meter.reset()
        net1.eval()
        net2.eval()
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.cuda(), targets.cuda()
                outputs = net1(inputs) + net2(inputs)
                acc_meter.add(outputs, targets)
        return acc_meter.value()
    
    
    def eval_train(model,all_loss):
        """Compute per-sample training losses of `model` and fit a 2-component GMM.

        Reads the module-level `eval_loader` (batches of (inputs, targets, index))
        and `CE` (per-sample cross-entropy). Losses are min-max normalised before
        the GMM fit.

        Args:
            model: network to evaluate.
            all_loss: history list; this epoch's normalised losses are appended.

        Returns:
            (prob, all_loss): `prob[i]` is sample i's posterior under the GMM
            component with the smaller mean (the caller thresholds it with
            args.p_threshold); `all_loss` is the updated history list.
        """
        model.eval()
        num_iter = (len(eval_loader.dataset) // eval_loader.batch_size) + 1
        losses = torch.zeros(len(eval_loader.dataset))
        with torch.no_grad():
            for batch_idx, (inputs, targets, index) in enumerate(eval_loader):
                inputs, targets = inputs.cuda(), targets.cuda()
                outputs = model(inputs)
                loss = CE(outputs, targets)
                # One vectorised scatter instead of a Python loop doing a GPU->CPU
                # copy per sample; cast to the buffer's dtype explicitly, as the
                # per-element assignment did implicitly.
                losses[index] = loss.float().cpu()
                if (batch_idx + 1) % 100 == 1:
                    sys.stdout.write('\r')
                    sys.stdout.write('| Evaluating loss Iter[%3d/%3d]\t' % (batch_idx, num_iter))
                    sys.stdout.flush()

        # Min-max normalise. If every loss is identical the range is 0, and
        # dividing would fill the tensor with NaNs that break the GMM fit.
        loss_range = losses.max() - losses.min()
        if loss_range > 0:
            losses = (losses - losses.min()) / loss_range
        else:
            losses = torch.zeros_like(losses)
        all_loss.append(losses)

        # fit a two-component GMM to the loss
        input_loss = losses.reshape(-1, 1)
        gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4)
        gmm.fit(input_loss)
        prob = gmm.predict_proba(input_loss)
        prob = prob[:, gmm.means_.argmin()]
        return prob, all_loss
    
    def linear_rampup(current, warm_up, rampup_length=16):
        """Scale args.lambda_u by the fraction of `rampup_length` elapsed since `warm_up`.

        The fraction is clipped to [0, 1], so the result is 0 before `warm_up` and
        args.lambda_u once the ramp has finished.
        """
        progress = np.clip((current - warm_up) / rampup_length, 0.0, 1.0)
        weight = float(progress)
        return args.lambda_u * weight
    
    class SemiLoss(object):
        """Callable returning (labeled loss, unlabeled loss, unlabeled-loss weight)."""

        def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up):
            probs_u = torch.softmax(outputs_u, dim=1)

            # Soft-target cross-entropy for the labeled part.
            log_probs_x = F.log_softmax(outputs_x, dim=1)
            supervised = -torch.mean(torch.sum(log_probs_x * targets_x, dim=1))
            # Mean squared error between predicted and target distributions.
            unsupervised = torch.mean((probs_u - targets_u) ** 2)

            ramp_weight = linear_rampup(epoch, warm_up)
            return supervised, unsupervised, ramp_weight
    
    class NegEntropy(object):
        """Callable returning the batch mean of sum_c p_c * log p_c, with p = softmax(outputs).

        This is the negative entropy of the predicted distribution, averaged over
        rows (dim 0); `outputs` are logits with classes along dim 1.
        """

        def __call__(self,outputs):
            probs = torch.softmax(outputs, dim=1)
            # log_softmax instead of probs.log(): if a softmax entry underflows to 0,
            # probs.log() is -inf and 0 * -inf yields NaN; log_softmax stays finite.
            return torch.mean(torch.sum(F.log_softmax(outputs, dim=1) * probs, dim=1))
    
    def create_model():
        """Build the model from swin_base_patch4_window7_224_in22k() on GPU args.local_rank."""
        # Alternative backbone: InceptionResNetV2(num_classes=args.num_class)
        return swin_base_patch4_window7_224_in22k().cuda(args.local_rank)
    
    stats_log = open('./checkpoint/%s' % (args.id) + '_stats.txt', 'w')
    test_log = open('./checkpoint/%s' % (args.id) + '_acc.txt', 'w')

    warm_up = 1

    loader = dataloader.webvision_dataloader(batch_size=args.batch_size, num_workers=5, root_dir=args.data_path, log=stats_log, num_class=args.num_class)

    print('| Building net')
    net1 = create_model()
    net2 = create_model()
    cudnn.benchmark = True

    criterion = SemiLoss()
    optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

    # Mixed precision first, then wrap each network in apex DDP.
    net1, optimizer1 = amp.initialize(net1, optimizer1, opt_level=args.opt_level)
    net2, optimizer2 = amp.initialize(net2, optimizer2, opt_level=args.opt_level)
    net1 = DDP(net1)
    net2 = DDP(net2)
    CE = nn.CrossEntropyLoss(reduction='none')  # per-sample losses, used by eval_train()
    CEloss = nn.CrossEntropyLoss()              # mean loss, used by warmup()
    conf_penalty = NegEntropy()
    all_loss = [[], []]  # save the history of losses from two networks
    acc_meter = torchnet.meter.ClassErrorMeter(topk=[1, 5], accuracy=True)

    # Every distributed process runs this script; only rank 0 writes checkpoints
    # and the accuracy log, so the processes do not race on the same files.
    is_main_process = dist.get_rank() == 0

    best_acc = [0, 0]  # ensemble [top-1, top-5] accuracy at the best top-1 epoch
    for epoch in range(args.num_epochs + 1):
        lr = args.lr
        if epoch >= 40:
            lr /= 10
        for param_group in optimizer1.param_groups:
            param_group['lr'] = lr
        for param_group in optimizer2.param_groups:
            param_group['lr'] = lr
        eval_loader = loader.run('eval_train')
        web_valloader = loader.run('test')

        if epoch < warm_up:
            warmup_trainloader = loader.run('warmup')
            print('Warmup Net1')
            warmup(epoch, net1, optimizer1, warmup_trainloader)
            print('\nWarmup Net2')
            warmup(epoch, net2, optimizer2, warmup_trainloader)
        else:
            # prob1 / prob2 come from eval_train() at the end of the previous epoch.
            pred1 = (prob1 > args.p_threshold)
            pred2 = (prob2 > args.p_threshold)

            print('Train Net1')
            labeled_trainloader, unlabeled_trainloader = loader.run('train', pred2, prob2)  # co-divide
            train(epoch, net1, net2, optimizer1, labeled_trainloader, unlabeled_trainloader)  # train net1

            print('\nTrain Net2')
            labeled_trainloader, unlabeled_trainloader = loader.run('train', pred1, prob1)  # co-divide
            train(epoch, net2, net1, optimizer2, labeled_trainloader, unlabeled_trainloader)  # train net2

        # test() scores the net1 + net2 ensemble, so web_acc is the pair's
        # [top-1, top-5] accuracy, not one accuracy per network. Checkpoint both
        # networks together whenever ensemble top-1 improves.
        web_acc = test(epoch, net1, net2, web_valloader)
        if web_acc[0] > best_acc[0]:
            best_acc = [web_acc[0], web_acc[1]]
            if is_main_process:
                for k, net in enumerate((net1, net2)):
                    print('| Saving Best Net%d ...' % k)
                    save_point = './checkpoint/%s_net%d.pth.tar' % (args.id, k)
                    torch.save(net.state_dict(), save_point)
        # ImageNet evaluation (loader.run('imagenet')) is currently disabled.

        print("\n| Test Epoch #%d\t WebVision Acc: %.2f%% (%.2f%%) \n" % (epoch, web_acc[0], web_acc[1]))
        if is_main_process:
            test_log.write('Epoch:%d \t WebVision Acc: %.2f%% (%.2f%%) \n' % (epoch, web_acc[0], web_acc[1]))
            test_log.flush()

        print('\n==== net 1 evaluate training data loss ====')
        prob1, all_loss[0] = eval_train(net1, all_loss[0])
        print('\n==== net 2 evaluate training data loss ====')
        prob2, all_loss[1] = eval_train(net2, all_loss[1])
        if is_main_process:
            torch.save(all_loss, './checkpoint/%s.pth.tar' % (args.id))

    stats_log.close()
    test_log.close()
    

      

    multiprocessing:

    from __future__ import print_function
    import sys
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    import torch.backends.cudnn as cudnn
    import random
    import os
    import sys
    import argparse
    import numpy as np
    # from InceptionResNetV2 import *
    from swin_transformer import *
    from sklearn.mixture import GaussianMixture
    # import dataloader_webvision as dataloader
    import dataloader_aliproduct as dataloader
    import torchnet
    import torch.multiprocessing as mp
    import time
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import *
    from apex import amp, optimizers
    from apex.multi_tensor_apply import multi_tensor_applier
    import apex
    parser = argparse.ArgumentParser(description='PyTorch WebVision Parallel Training')
    parser.add_argument('--batch_size', default=96, type=int, help='train batchsize')
    parser.add_argument('--lr', '--learning_rate', default=0.01, type=float, help='initial learning rate')
    parser.add_argument('--alpha', default=0.5, type=float, help='parameter for Beta')
    parser.add_argument('--lambda_u', default=0, type=float, help='weight for unsupervised loss')
    parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold')
    parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
    parser.add_argument('--num_epochs', default=100, type=int)
    parser.add_argument('--id', default='', type=str)
    # type=int so a seed given on the command line behaves like the integer
    # default (random.seed('123') and random.seed(123) give different streams).
    parser.add_argument('--seed', default=123, type=int)
    parser.add_argument('--gpuid1', default=1, type=int)
    parser.add_argument('--gpuid2', default=2, type=int)
    parser.add_argument('--num_class', default=50030, type=int)
    parser.add_argument('--data_path', default='./dataset/', type=str, help='path to dataset')
    parser.add_argument('--opt-level', default='O2', type=str)

    args = parser.parse_args()

    # Expose only the two selected GPUs; inside this program they are then
    # addressed as cuda:0 (net1's device) and cuda:1 (net2's device).
    os.environ["CUDA_VISIBLE_DEVICES"] = '%s,%s' % (args.gpuid1, args.gpuid2)
    random.seed(args.seed)
    cuda1 = torch.device('cuda:0')
    cuda2 = torch.device('cuda:1')
    
    # Training
    def train(epoch,net,net2,optimizer,labeled_trainloader,unlabeled_trainloader,device,whichnet):
        """Train `net` on `device` for one epoch, co-guessing labels with the fixed `net2`.

        Intended to be run in a spawned worker process. `net2` must live on the
        same `device` and is used only under no_grad.

        Args:
            epoch: current epoch, used only in the progress line.
            net: network updated by `optimizer`.
            net2: peer network, kept fixed.
            optimizer: optimizer over `net`'s parameters.
            labeled_trainloader: yields (inputs_x, inputs_x2, labels_x, w_x).
            unlabeled_trainloader: yields (inputs_u, inputs_u2); restarted whenever
                it runs out before the labeled loader does.
            device: torch.device the batches are moved to.
            whichnet: label shown in the progress line.
        """
        net.train()
        net2.eval()  # fix one network and train the other

        unlabeled_train_iter = iter(unlabeled_trainloader)
        num_iter = (len(labeled_trainloader.dataset) // args.batch_size) + 1
        # Only the trained network is registered with amp, together with its own
        # optimizer. An optimizer may be passed through amp.initialize only once,
        # and it does not own net2's parameters, so net2 is left untouched.
        # NOTE(review): this still runs amp.initialize on every call (one fresh
        # worker per epoch); with O2 it casts `net` in place -- confirm the
        # parent process's copy of the weights actually receives the updates.
        net, optimizer = amp.initialize(net, optimizer, opt_level=args.opt_level)
        for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader):
            # next() is the Python 3 iterator protocol (iterator `.next()` is not);
            # only StopIteration means "restart the unlabeled loader".
            try:
                inputs_u, inputs_u2 = next(unlabeled_train_iter)
            except StopIteration:
                unlabeled_train_iter = iter(unlabeled_trainloader)
                inputs_u, inputs_u2 = next(unlabeled_train_iter)
            batch_size = inputs_x.size(0)

            # Transform label to one-hot
            labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1, 1), 1)
            w_x = w_x.view(-1, 1).type(torch.FloatTensor)

            inputs_x, inputs_x2, labels_x, w_x = inputs_x.to(device, non_blocking=True), inputs_x2.to(device, non_blocking=True), labels_x.to(device, non_blocking=True), w_x.to(device, non_blocking=True)
            inputs_u, inputs_u2 = inputs_u.to(device), inputs_u2.to(device)

            with torch.no_grad():
                # label co-guessing of unlabeled samples: average both networks'
                # predictions over both views
                outputs_u11 = net(inputs_u)
                outputs_u12 = net(inputs_u2)
                outputs_u21 = net2(inputs_u)
                outputs_u22 = net2(inputs_u2)

                pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1) + torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4
                ptu = pu ** (1 / args.T)  # temperature sharpening

                targets_u = ptu / ptu.sum(dim=1, keepdim=True)  # normalize
                targets_u = targets_u.detach()

                # label refinement of labeled samples: blend the given one-hot
                # label with net's own prediction, weighted per sample by w_x
                outputs_x = net(inputs_x)
                outputs_x2 = net(inputs_x2)

                px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2
                px = w_x * labels_x + (1 - w_x) * px
                ptx = px ** (1 / args.T)  # temperature sharpening

                targets_x = ptx / ptx.sum(dim=1, keepdim=True)  # normalize
                targets_x = targets_x.detach()

            # mixmatch: max(l, 1-l) keeps l >= 0.5 so each mix stays closer to
            # the unpermuted sample
            l = np.random.beta(args.alpha, args.alpha)
            l = max(l, 1 - l)

            all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
            all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)

            idx = torch.randperm(all_inputs.size(0))

            input_a, input_b = all_inputs, all_inputs[idx]
            target_a, target_b = all_targets, all_targets[idx]

            # Only the first 2*batch_size rows (the two labeled views) are trained on.
            mixed_input = l * input_a[:batch_size * 2] + (1 - l) * input_b[:batch_size * 2]
            mixed_target = l * target_a[:batch_size * 2] + (1 - l) * target_b[:batch_size * 2]

            logits = net(mixed_input)

            Lx = -torch.mean(torch.sum(F.log_softmax(logits, dim=1) * mixed_target, dim=1))

            # KL(uniform prior || batch-mean prediction): pushes the average
            # prediction towards the uniform distribution over classes.
            prior = torch.ones(args.num_class) / args.num_class
            prior = prior.to(device)
            pred_mean = torch.softmax(logits, dim=1).mean(0)
            penalty = torch.sum(prior * torch.log(prior / pred_mean))

            loss = Lx + penalty
            # compute gradient (with amp loss scaling) and do SGD step
            optimizer.zero_grad()
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            if (batch_idx + 1) % 100 == 1:
                sys.stdout.write('\r')
                sys.stdout.write('%s |%s Epoch [%3d/%3d] Iter[%4d/%4d]\t Labeled loss: %.2f'
                        % (args.id, whichnet, epoch, args.num_epochs, batch_idx + 1, num_iter, Lx.item()))
                sys.stdout.flush()
    
    def warmup(epoch,net,optimizer,dataloader,device,whichnet,max_iters=102):
        """Warm up `net` on `device` with plain cross-entropy, for at most `max_iters` batches.

        Intended to be run in a spawned worker process. The default of 102 batches
        reproduces the previous hard-coded cut-off; pass None for a full epoch.

        Args:
            epoch: current epoch, used only in the progress line.
            net: network to train.
            optimizer: optimizer over `net`'s parameters.
            dataloader: yields (inputs, labels, path); `path` is ignored.
            device: torch.device the batches are moved to.
            whichnet: label shown in the progress line.
            max_iters: number of batches to train on, or None for all of them.
        """
        CEloss = nn.CrossEntropyLoss()
        start1 = time.time()
        net.train()
        num_iter = (len(dataloader.dataset) // dataloader.batch_size) + 1
        # NOTE(review): amp.initialize runs on every call (one fresh worker per epoch).
        net, optimizer = amp.initialize(net, optimizer, opt_level=args.opt_level)
        for batch_idx, (inputs, labels, _path) in enumerate(dataloader):
            # Stop here rather than skipping: iterating on would still load
            # (and then discard) every remaining batch.
            if max_iters is not None and batch_idx >= max_iters:
                break
            inputs, labels = inputs.to(device), labels.to(device, non_blocking=True)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = CEloss(outputs, labels)
            # A confidence penalty could be added to `loss` here.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            if (batch_idx + 1) % 100 == 1:
                end1 = time.time()
                sys.stdout.write('\r')
                sys.stdout.write('%s |%s  Epoch [%3d/%3d] Iter[%4d/%4d]\t CE-loss: %.4f\t %.4f'
                        % (args.id, whichnet, epoch, args.num_epochs, batch_idx + 1, num_iter, loss.item(), end1 - start1))
                sys.stdout.flush()
    
            
    def test(epoch,net1,net2,test_loader,device,queue):
        """Evaluate the ensemble of `net1` and `net2` (summed logits) on `device`.

        Intended to be run in a spawned worker process: instead of returning, it
        puts the value of a ClassErrorMeter(topk=[1, 5], accuracy=True) -- the
        ensemble's [top-1, top-5] accuracy -- on `queue`. Both networks must live on
        `device`. `epoch` is unused and kept for call compatibility.
        """
        acc_meter = torchnet.meter.ClassErrorMeter(topk=[1, 5], accuracy=True)
        net1.eval()
        net2.eval()
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device, non_blocking=True)
                outputs = net1(inputs) + net2(inputs)
                acc_meter.add(outputs, targets)
        queue.put(acc_meter.value())
    
    
    def eval_train(eval_loader,model,device,whichnet,queue):
        """Compute per-sample training losses of `model` and fit a 2-component GMM.

        Intended to be run in a spawned worker process: instead of returning, it
        puts `prob` on `queue`, where `prob[i]` is sample i's posterior under the
        GMM component with the smaller mean (the caller thresholds it with
        args.p_threshold).

        Args:
            eval_loader: yields (inputs, targets, index) with dataset indices.
            model: network to evaluate.
            device: torch.device the batches are moved to.
            whichnet: label shown in the progress line.
            queue: multiprocessing queue receiving the result.
        """
        CE = nn.CrossEntropyLoss(reduction='none')
        model.eval()
        num_iter = (len(eval_loader.dataset) // eval_loader.batch_size) + 1
        losses = torch.zeros(len(eval_loader.dataset))
        with torch.no_grad():
            for batch_idx, (inputs, targets, index) in enumerate(eval_loader):
                inputs, targets = inputs.to(device), targets.to(device, non_blocking=True)
                outputs = model(inputs)
                loss = CE(outputs, targets)
                # One vectorised scatter instead of a Python loop doing a GPU->CPU
                # copy per sample; cast to the buffer's dtype explicitly (outputs
                # may be fp16 under amp), as the per-element assignment did implicitly.
                losses[index] = loss.float().cpu()
                if (batch_idx + 1) % 100 == 1:
                    sys.stdout.write('\r')
                    sys.stdout.write('|%s Evaluating loss Iter[%3d/%3d]\t' % (whichnet, batch_idx, num_iter))
                    sys.stdout.flush()

        # Min-max normalise. If every loss is identical the range is 0, and
        # dividing would fill the tensor with NaNs that break the GMM fit.
        loss_range = losses.max() - losses.min()
        if loss_range > 0:
            losses = (losses - losses.min()) / loss_range
        else:
            losses = torch.zeros_like(losses)

        # fit a two-component GMM to the loss
        input_loss = losses.reshape(-1, 1)
        gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=1e-3)
        gmm.fit(input_loss)
        prob = gmm.predict_proba(input_loss)
        prob = prob[:, gmm.means_.argmin()]
        queue.put(prob)
    
    def linear_rampup(current, warm_up, rampup_length=16):
        """Scale args.lambda_u by the fraction of `rampup_length` elapsed since `warm_up`.

        The fraction is clipped to [0, 1], so the result is 0 before `warm_up` and
        args.lambda_u once the ramp has finished.
        """
        progress = np.clip((current - warm_up) / rampup_length, 0.0, 1.0)
        weight = float(progress)
        return args.lambda_u * weight
    
    class SemiLoss(object):
        """Callable returning (labeled loss, unlabeled loss, unlabeled-loss weight)."""

        def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up):
            probs_u = torch.softmax(outputs_u, dim=1)

            # Soft-target cross-entropy for the labeled part.
            log_probs_x = F.log_softmax(outputs_x, dim=1)
            supervised = -torch.mean(torch.sum(log_probs_x * targets_x, dim=1))
            # Mean squared error between predicted and target distributions.
            unsupervised = torch.mean((probs_u - targets_u) ** 2)

            ramp_weight = linear_rampup(epoch, warm_up)
            return supervised, unsupervised, ramp_weight
    
    class NegEntropy(object):
        """Callable returning the batch mean of sum_c p_c * log p_c, with p = softmax(outputs).

        This is the negative entropy of the predicted distribution, averaged over
        rows (dim 0); `outputs` are logits with classes along dim 1.
        """

        def __call__(self,outputs):
            probs = torch.softmax(outputs, dim=1)
            # log_softmax instead of probs.log(): if a softmax entry underflows to 0,
            # probs.log() is -inf and 0 * -inf yields NaN; log_softmax stays finite.
            return torch.mean(torch.sum(F.log_softmax(outputs, dim=1) * probs, dim=1))
    
    def create_model(device):
        """Build the model from swin_base_patch4_window7_224_in22k() and move it to `device`."""
        # Alternative backbone: InceptionResNetV2(num_classes=args.num_class)
        net = swin_base_patch4_window7_224_in22k()
        return net.to(device)
    
    if __name__ == "__main__":
        
        mp.set_start_method('spawn')
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)    
        
        stats_log=open('./checkpoint/%s'%(args.id)+'_stats.txt','w') 
        test_log=open('./checkpoint/%s'%(args.id)+'_acc.txt','w')         
        
        warm_up=1
    
        loader = dataloader.webvision_dataloader(batch_size=args.batch_size,num_class = args.num_class,num_workers=8,root_dir=args.data_path,log=stats_log)
    
        print('| Building net')
        
        net1 = create_model(cuda1)
        net2 = create_model(cuda2)
        
        net1_clone = create_model(cuda2)
        net2_clone = create_model(cuda1)
        
        cudnn.benchmark = True
        
        optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    
        # net1, optimizer1 = amp.initialize(net1, optimizer1,
        #                                   opt_level=args.opt_level
        #                                   )
        # net2, optimizer2 = amp.initialize(net2, optimizer2,
        #                                   opt_level=args.opt_level
        #                                   )
        net1_clone = create_model(cuda2)
        net2_clone = create_model(cuda1)
        
                                         
        # net1 = DDP(net1)
        # net2 = DDP(net2)
        #conf_penalty = NegEntropy()    
        web_valloader = loader.run('test')
        # imagenet_valloader = loader.run('imagenet')   
        
        best_acc = [0,0]
        for epoch in range(args.num_epochs+1):   
            time_start=time.time()
            lr=args.lr
            if epoch >= 50:
                lr /= 10      
            for param_group in optimizer1.param_groups:
                param_group['lr'] = lr       
            for param_group in optimizer2.param_groups:
                param_group['lr'] = lr              
    
            if epoch<warm_up:  
                warmup_trainloader1 = loader.run('warmup')
                warmup_trainloader2 = loader.run('warmup')
                print("111")
                p1 = mp.Process(target=warmup, args=(epoch,net1,optimizer1,warmup_trainloader1,cuda1,'net1'))                      
                p2 = mp.Process(target=warmup, args=(epoch,net2,optimizer2,warmup_trainloader2,cuda2,'net2'))
                print("222")
                p1.start() 
                p2.start()     
                print("333") 
    
            else:                
                pred1 = (prob1 > args.p_threshold)      
                pred2 = (prob2 > args.p_threshold)      
    
                labeled_trainloader1, unlabeled_trainloader1 = loader.run('train',pred2,prob2) # co-divide
                labeled_trainloader2, unlabeled_trainloader2 = loader.run('train',pred1,prob1) # co-divide
                
                p1 = mp.Process(target=train, args=(epoch,net1,net2_clone,optimizer1,labeled_trainloader1, unlabeled_trainloader1,cuda1,'net1'))                             
                p2 = mp.Process(target=train, args=(epoch,net2,net1_clone,optimizer2,labeled_trainloader2, unlabeled_trainloader2,cuda2,'net2'))
                p1.start()  
                p2.start()               
            p1.join()
            p2.join()
            print("444") 
            net1_clone.load_state_dict(net1.state_dict())
            net2_clone.load_state_dict(net2.state_dict())
            print("3")
            q1 = mp.Queue()
            q2 = mp.Queue()
            p1 = mp.Process(target=test, args=(epoch,net1,net2_clone,web_valloader,cuda1,q1)) 
            p2 = mp.Process(target=test, args=(epoch,net1_clone,net2,web_valloader,cuda2,q2))                
            # p2 = mp.Process(target=test, args=(epoch,net1_clone,net2,imagenet_valloader,cuda2,q2))
            print("4")
            p1.start()   
            p2.start()
            
            web_acc = q1.get()
            print("5")
            # imagenet_acc = q2.get()
            k = 0
            if web_acc[k] > best_acc[k]:
                best_acc[k] = web_acc[k]
                print('| Saving Best Net%d ...'%k)
                save_point = './checkpoint/%s_net%d.pth.tar'%(args.id,k)
                torch.save(net1.state_dict(), save_point)
            k = 1
            if web_acc[k] > best_acc[k]:
                best_acc[k] = web_acc[k]
                print('| Saving Best Net%d ...'%k)
                save_point = './checkpoint/%s_net%d.pth.tar'%(args.id,k)
                torch.save(net2.state_dict(), save_point)
            
            p1.join()
            # p2.join()        
            time_end=time.time()
            print("
    | Test Epoch #%d	 WebVision Acc: %.2f%% (%.2f%%) 
    "%(epoch,web_acc[0],web_acc[1]))  
            test_log.write('Epoch:%d 	 WebVision Acc: %.2f%% (%.2f%%)	 %.2f
    '%(epoch,web_acc[0],web_acc[1],time_end-time_start))
            test_log.flush()  
            
            eval_loader1 = loader.run('eval_train')          
            eval_loader2 = loader.run('eval_train')       
            q1 = mp.Queue()
            q2 = mp.Queue()
            p1 = mp.Process(target=eval_train, args=(eval_loader1,net1,cuda1,'net1',q1))                
            p2 = mp.Process(target=eval_train, args=(eval_loader2,net2,cuda2,'net2',q2))
            print("6")
            p1.start()   
            p2.start()
            
            prob1 = q1.get()
            prob2 = q2.get()
            
            p1.join()
            p2.join()
    

      

  • 相关阅读:
    nginx:安装成windows服务
    org.aspectj.apache.bcel.classfile.ClassFormatException: Invalid byte tag in constant pool: 18
    数据库中间件
    架构策略
    谈判
    设计模式 总结 常用10种
    08 状态模式 state
    07 策略模式 strategy
    06 命令模式(不用)
    05 观察者模式 Observer
  • 原文地址:https://www.cnblogs.com/ziytong/p/14746036.html
Copyright © 2011-2022 走看看