Study notes based on: https://zhuanlan.zhihu.com/p/136372142
DDP:
from __future__ import print_function
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.distributed as dist
import torch.utils.data
import torch.utils.data.distributed
import random
import os
import argparse
import numpy as np
from InceptionResNetV2 import *
from swin_transformer import *
from sklearn.mixture import GaussianMixture
import dataloader_aliproduct as dataloader
import torchnet
import time
from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import *
from apex import amp, optimizers
from apex.multi_tensor_apply import multi_tensor_applier
import apex
parser = argparse.ArgumentParser(description='PyTorch WebVision Training')
parser.add_argument('--batch_size', default=64, type=int, help='train batchsize')
parser.add_argument('--lr', '--learning_rate', default=0.01, type=float, help='initial learning rate')
parser.add_argument('--alpha', default=0.5, type=float, help='parameter for Beta')
parser.add_argument('--lambda_u', default=0, type=float, help='weight for unsupervised loss')
parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold')
parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
parser.add_argument('--num_epochs', default=200, type=int)
parser.add_argument('--id', default='',type=str)
parser.add_argument('--seed', default=123, type=int)
parser.add_argument('--gpuid', default=0, type=int)
parser.add_argument('--num_class', default=50030, type=int)
parser.add_argument('--data_path', default='./dataset/', type=str, help='path to dataset')
parser.add_argument('--opt-level', default='O1', type=str)
parser.add_argument("--local_rank", default=4, type=int)
parser.add_argument('--world-size', default=2, type=int,
help='number of nodes for distributed training')
# parser.add_argument('--rank', default=3, type=int,
# help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
help='distributed backend')
parser.add_argument('--gpu', default=-1, type=int,
help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
help='Use multi-processing distributed training to launch '
'N processes per node, which has N GPUs. This is the '
'fastest way to use PyTorch for either single node or '
'multi node data parallel training')
# # cudnn.benchmark = True
# args = parser.parse_args()
# # torch.cuda.set_device(args.gpuid)
# assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
args = parser.parse_args()
random.seed(args.seed)
# torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
args.distributed = args.world_size > 1 or args.multiprocessing_distributed
ngpus_per_node = torch.cuda.device_count()
if args.gpu < 0:
    print("must specify a GPU id via --gpu")
print("Use GPU: {} for training".format(args.gpu))
torch.cuda.set_device(args.local_rank)
# For multiprocessing distributed training, rank needs to be the
# global rank among all the processes
dist.init_process_group(backend='nccl', init_method='env://')
print("after dist init")
# Training
def train(epoch,net,net2,optimizer,labeled_trainloader,unlabeled_trainloader):
net.train()
net2.eval() #fix one network and train the other
unlabeled_train_iter = iter(unlabeled_trainloader)
num_iter = (len(labeled_trainloader.dataset)//args.batch_size)+1
for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader):
        try:
            inputs_u, inputs_u2 = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            inputs_u, inputs_u2 = next(unlabeled_train_iter)
batch_size = inputs_x.size(0)
# Transform label to one-hot
labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1,1), 1)
w_x = w_x.view(-1,1).type(torch.FloatTensor)
inputs_x, inputs_x2, labels_x, w_x = inputs_x.cuda(), inputs_x2.cuda(), labels_x.cuda(), w_x.cuda()
inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda()
with torch.no_grad():
# label co-guessing of unlabeled samples
outputs_u11 = net(inputs_u)
outputs_u12 = net(inputs_u2)
outputs_u21 = net2(inputs_u)
outputs_u22 = net2(inputs_u2)
pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1) + torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4
            ptu = pu**(1/args.T) # temperature sharpening
targets_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize
targets_u = targets_u.detach()
# label refinement of labeled samples
outputs_x = net(inputs_x)
outputs_x2 = net(inputs_x2)
px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2
px = w_x*labels_x + (1-w_x)*px
            ptx = px**(1/args.T) # temperature sharpening
targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize
targets_x = targets_x.detach()
# mixmatch
l = np.random.beta(args.alpha, args.alpha)
l = max(l, 1-l)
all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)
idx = torch.randperm(all_inputs.size(0))
input_a, input_b = all_inputs, all_inputs[idx]
target_a, target_b = all_targets, all_targets[idx]
mixed_input = l * input_a[:batch_size*2] + (1 - l) * input_b[:batch_size*2]
mixed_target = l * target_a[:batch_size*2] + (1 - l) * target_b[:batch_size*2]
logits = net(mixed_input)
Lx = -torch.mean(torch.sum(F.log_softmax(logits, dim=1) * mixed_target, dim=1))
prior = torch.ones(args.num_class)/args.num_class
prior = prior.cuda()
pred_mean = torch.softmax(logits, dim=1).mean(0)
penalty = torch.sum(prior*torch.log(prior/pred_mean))
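        # `penalty` above is KL(prior || pred_mean) with a uniform prior: DivideMix's
        # regularization term that stops predictions from collapsing onto a few classes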
loss = Lx + penalty
# compute gradient and do SGD step
optimizer.zero_grad()
# loss.backward()
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
        if batch_idx % 100 == 0:
            sys.stdout.write('\n')
            sys.stdout.write('%s | Epoch [%3d/%3d] Iter[%4d/%4d] Labeled loss: %.2f'
                    %(args.id, epoch, args.num_epochs, batch_idx+1, num_iter, Lx.item()))
            sys.stdout.flush()
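# Worked example of the temperature sharpening above (a minimal sketch): with
# T = 0.5, p**(1/T) squares each probability before renormalizing, so
#   p = [0.6, 0.4]  ->  p**2 = [0.36, 0.16]  ->  normalized ~= [0.69, 0.31],
# i.e. the guessed label distribution gets more peaked as T decreases.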
def warmup(epoch,net,optimizer,dataloader):
start1 = time.time()
net.train()
num_iter = (len(dataloader.dataset)//dataloader.batch_size)+1
for batch_idx, (inputs, labels, path) in enumerate(dataloader):
inputs, labels = inputs.cuda(), labels.cuda()
optimizer.zero_grad()
outputs = net(inputs)
loss = CEloss(outputs, labels)
#penalty = conf_penalty(outputs)
L = loss #+ penalty
# L.backward()
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
        if batch_idx % 100 == 0:
            end1 = time.time()
            sys.stdout.write('\n')
            sys.stdout.write('%s | Epoch [%3d/%3d] Iter[%4d/%4d] CE-loss: %.4f  time: %.4f\n'
                    %(args.id, epoch, args.num_epochs, batch_idx+1, num_iter, loss.item(), end1 - start1))
            sys.stdout.flush()
def test(epoch,net1,net2,test_loader):
acc_meter.reset()
net1.eval()
net2.eval()
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(test_loader):
inputs, targets = inputs.cuda(), targets.cuda()
outputs1 = net1(inputs)
outputs2 = net2(inputs)
outputs = outputs1+outputs2
_, predicted = torch.max(outputs, 1)
acc_meter.add(outputs,targets)
accs = acc_meter.value()
return accs
def eval_train(model,all_loss):
model.eval()
num_iter = (len(eval_loader.dataset)//eval_loader.batch_size)+1
losses = torch.zeros(len(eval_loader.dataset))
with torch.no_grad():
for batch_idx, (inputs, targets, index) in enumerate(eval_loader):
inputs, targets = inputs.cuda(), targets.cuda()
outputs = model(inputs)
loss = CE(outputs, targets)
for b in range(inputs.size(0)):
losses[index[b]]=loss[b]
            if batch_idx % 100 == 0:
                sys.stdout.write('\n')
                sys.stdout.write('| Evaluating loss Iter[%3d/%3d] ' %(batch_idx,num_iter))
                sys.stdout.flush()
losses = (losses-losses.min())/(losses.max()-losses.min())
all_loss.append(losses)
# fit a two-component GMM to the loss
input_loss = losses.reshape(-1,1)
gmm = GaussianMixture(n_components=2,max_iter=10,tol=1e-2,reg_covar=5e-4)
gmm.fit(input_loss)
prob = gmm.predict_proba(input_loss)
prob = prob[:,gmm.means_.argmin()]
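    # gmm.means_.argmin() selects the mixture component with the smaller mean loss;
    # its posterior is used as each sample's probability of having a clean label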
return prob,all_loss
def linear_rampup(current, warm_up, rampup_length=16):
current = np.clip((current-warm_up) / rampup_length, 0.0, 1.0)
return args.lambda_u*float(current)
class SemiLoss(object):
def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up):
probs_u = torch.softmax(outputs_u, dim=1)
Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
Lu = torch.mean((probs_u - targets_u)**2)
return Lx, Lu, linear_rampup(epoch,warm_up)
class NegEntropy(object):
def __call__(self,outputs):
probs = torch.softmax(outputs, dim=1)
return torch.mean(torch.sum(probs.log()*probs, dim=1))
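# Note: with the default --lambda_u 0, SemiLoss's ramped unsupervised weight is
# always zero, and NegEntropy (the confidence penalty) is only used if the
# commented-out `conf_penalty` line in warmup() is re-enabled.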
def create_model():
model = swin_base_patch4_window7_224_in22k()#InceptionResNetV2(num_classes=args.num_class) # #
model = model.cuda(args.local_rank)
return model
stats_log=open('./checkpoint/%s'%(args.id)+'_stats.txt','w')
test_log=open('./checkpoint/%s'%(args.id)+'_acc.txt','w')
warm_up=1
loader = dataloader.webvision_dataloader(batch_size=args.batch_size,num_workers=5,root_dir=args.data_path,log=stats_log, num_class=args.num_class)
print('| Building net')
net1 = create_model()
net2 = create_model()
cudnn.benchmark = True
criterion = SemiLoss()
optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
net1, optimizer1 = amp.initialize(net1, optimizer1,
opt_level=args.opt_level)
net2, optimizer2 = amp.initialize(net2, optimizer2,
opt_level=args.opt_level
)
net1 = DDP(net1)
net2 = DDP(net2)#, device_ids=[args.local_rank],output_device=args.local_rank, find_unused_parameters=True)
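# Per apex's documented usage, amp.initialize must be called before wrapping the
# models with DistributedDataParallel, which is the order used above; apex DDP
# then all-reduces gradients across processes during backward.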
CE = nn.CrossEntropyLoss(reduction='none')
CEloss = nn.CrossEntropyLoss()
conf_penalty = NegEntropy()
# print("1")
all_loss = [[],[]] # save the history of losses from two networks
acc_meter = torchnet.meter.ClassErrorMeter(topk=[1,5], accuracy=True)
best_acc = [0,0]
for epoch in range(args.num_epochs+1):
lr=args.lr
if epoch >= 40:
lr /= 10
for param_group in optimizer1.param_groups:
param_group['lr'] = lr
for param_group in optimizer2.param_groups:
param_group['lr'] = lr
eval_loader = loader.run('eval_train')
web_valloader = loader.run('test')
# imagenet_valloader = loader.run('imagenet')
if epoch<warm_up:
warmup_trainloader = loader.run('warmup')
print('Warmup Net1')
warmup(epoch,net1,optimizer1,warmup_trainloader)
        print('\nWarmup Net2')
warmup(epoch,net2,optimizer2,warmup_trainloader)
else:
pred1 = (prob1 > args.p_threshold)
pred2 = (prob2 > args.p_threshold)
print('Train Net1')
labeled_trainloader, unlabeled_trainloader = loader.run('train',pred2,prob2) # co-divide
train(epoch,net1,net2,optimizer1,labeled_trainloader, unlabeled_trainloader) # train net1
        print('\nTrain Net2')
labeled_trainloader, unlabeled_trainloader = loader.run('train',pred1,prob1) # co-divide
train(epoch,net2,net1,optimizer2,labeled_trainloader, unlabeled_trainloader) # train net2
web_acc = test(epoch,net1,net2,web_valloader)
k = 0
if web_acc[k] > best_acc[k]:
best_acc[k] = web_acc[k]
print('| Saving Best Net%d ...'%k)
save_point = './checkpoint/%s_net%d.pth.tar'%(args.id,k)
torch.save(net1.state_dict(), save_point)
k = 1
if web_acc[k] > best_acc[k]:
best_acc[k] = web_acc[k]
print('| Saving Best Net%d ...'%k)
save_point = './checkpoint/%s_net%d.pth.tar'%(args.id,k)
torch.save(net2.state_dict(), save_point)
# imagenet_acc = test(epoch,net1,net2,imagenet_valloader)
# print("
| Test Epoch #%d WebVision Acc: %.2f%% (%.2f%%) ImageNet Acc: %.2f%% (%.2f%%)
"%(epoch,web_acc[0],web_acc[1],imagenet_acc[0],imagenet_acc[1]))
# test_log.write('Epoch:%d WebVision Acc: %.2f%% (%.2f%%) ImageNet Acc: %.2f%% (%.2f%%)
'%(epoch,web_acc[0],web_acc[1],imagenet_acc[0],imagenet_acc[1]))
# test_log.flush()
print("
| Test Epoch #%d WebVision Acc: %.2f%% (%.2f%%)
"%(epoch,web_acc[0],web_acc[1]))
test_log.write('Epoch:%d WebVision Acc: %.2f%% (%.2f%%)
'%(epoch,web_acc[0],web_acc[1]))
test_log.flush()
    print('\n==== net 1 evaluate training data loss ====')
    prob1,all_loss[0]=eval_train(net1,all_loss[0])
    print('\n==== net 2 evaluate training data loss ====')
    prob2,all_loss[1]=eval_train(net2,all_loss[1])
torch.save(all_loss,'./checkpoint/%s.pth.tar'%(args.id))
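Aside: the co-divide step in eval_train boils down to fitting a two-component GMM to normalized per-sample losses and reading off the posterior of the low-mean component. A minimal self-contained sketch on synthetic losses (illustrative only, not part of the training scripts):

import numpy as np
from sklearn.mixture import GaussianMixture

# synthetic per-sample losses: a low-loss "clean" mode and a high-loss "noisy" mode
losses = np.concatenate([np.random.normal(0.2, 0.05, 900),
                         np.random.normal(0.8, 0.10, 100)])
losses = (losses - losses.min()) / (losses.max() - losses.min())  # min-max normalize

gmm = GaussianMixture(n_components=2, max_iter=10, tol=1e-2, reg_covar=5e-4)
gmm.fit(losses.reshape(-1, 1))
prob_clean = gmm.predict_proba(losses.reshape(-1, 1))[:, gmm.means_.argmin()]
pred_clean = prob_clean > 0.5   # same thresholding as --p_threshold
print(pred_clean.sum(), "of", len(losses), "samples flagged clean")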
multiprocessing:
from __future__ import print_function
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import random
import os
import argparse
import numpy as np
# from InceptionResNetV2 import *
from swin_transformer import *
from sklearn.mixture import GaussianMixture
# import dataloader_webvision as dataloader
import dataloader_aliproduct as dataloader
import torchnet
import torch.multiprocessing as mp
import time
from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import *
from apex import amp, optimizers
from apex.multi_tensor_apply import multi_tensor_applier
import apex
parser = argparse.ArgumentParser(description='PyTorch WebVision Parallel Training')
parser.add_argument('--batch_size', default=96, type=int, help='train batchsize')
parser.add_argument('--lr', '--learning_rate', default=0.01, type=float, help='initial learning rate')
parser.add_argument('--alpha', default=0.5, type=float, help='parameter for Beta')
parser.add_argument('--lambda_u', default=0, type=float, help='weight for unsupervised loss')
parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold')
parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
parser.add_argument('--num_epochs', default=100, type=int)
parser.add_argument('--id', default='',type=str)
parser.add_argument('--seed', default=123, type=int)
parser.add_argument('--gpuid1', default=1, type=int)
parser.add_argument('--gpuid2', default=2, type=int)
parser.add_argument('--num_class', default=50030, type=int)
parser.add_argument('--data_path', default='./dataset/', type=str, help='path to dataset')
parser.add_argument('--opt-level', default='O2', type=str)
args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = '%s,%s'%(args.gpuid1,args.gpuid2)
random.seed(args.seed)
cuda1 = torch.device('cuda:0')
cuda2 = torch.device('cuda:1')
# Training
def train(epoch,net,net2,optimizer,labeled_trainloader,unlabeled_trainloader,device,whichnet):
criterion = SemiLoss()
net.train()
net2.eval() #fix one network and train the other
unlabeled_train_iter = iter(unlabeled_trainloader)
num_iter = (len(labeled_trainloader.dataset)//args.batch_size)+1
    # amp must be (re)initialized inside each spawned worker, since apex's amp
    # state does not survive the process spawn; net2 is frozen here, so it is
    # initialized without rebinding the optimizer
    net, optimizer = amp.initialize(net, optimizer, opt_level=args.opt_level)
    net2 = amp.initialize(net2, opt_level=args.opt_level)
start1 = time.time()
for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(labeled_trainloader):
        try:
            inputs_u, inputs_u2 = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            inputs_u, inputs_u2 = next(unlabeled_train_iter)
batch_size = inputs_x.size(0)
# Transform label to one-hot
labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1,1), 1)
w_x = w_x.view(-1,1).type(torch.FloatTensor)
inputs_x, inputs_x2, labels_x, w_x = inputs_x.to(device,non_blocking=True), inputs_x2.to(device,non_blocking=True), labels_x.to(device,non_blocking=True), w_x.to(device,non_blocking=True)
inputs_u, inputs_u2 = inputs_u.to(device), inputs_u2.to(device)
with torch.no_grad():
# label co-guessing of unlabeled samples
outputs_u11 = net(inputs_u)
outputs_u12 = net(inputs_u2)
outputs_u21 = net2(inputs_u)
outputs_u22 = net2(inputs_u2)
pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1) + torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4
            ptu = pu**(1/args.T) # temperature sharpening
targets_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize
targets_u = targets_u.detach()
# label refinement of labeled samples
outputs_x = net(inputs_x)
outputs_x2 = net(inputs_x2)
px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2
px = w_x*labels_x + (1-w_x)*px
            ptx = px**(1/args.T) # temperature sharpening
targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize
targets_x = targets_x.detach()
# mixmatch
l = np.random.beta(args.alpha, args.alpha)
l = max(l, 1-l)
all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)
idx = torch.randperm(all_inputs.size(0))
input_a, input_b = all_inputs, all_inputs[idx]
target_a, target_b = all_targets, all_targets[idx]
mixed_input = l * input_a[:batch_size*2] + (1 - l) * input_b[:batch_size*2]
mixed_target = l * target_a[:batch_size*2] + (1 - l) * target_b[:batch_size*2]
logits = net(mixed_input)
Lx = -torch.mean(torch.sum(F.log_softmax(logits, dim=1) * mixed_target, dim=1))
prior = torch.ones(args.num_class)/args.num_class
prior = prior.to(device)
pred_mean = torch.softmax(logits, dim=1).mean(0)
penalty = torch.sum(prior*torch.log(prior/pred_mean))
loss = Lx + penalty
# compute gradient and do SGD step
optimizer.zero_grad()
# loss.backward()
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
        if batch_idx % 100 == 0:
            sys.stdout.write('\n')
            sys.stdout.write('%s |%s Epoch [%3d/%3d] Iter[%4d/%4d] Labeled loss: %.2f'
                    %(args.id, whichnet, epoch, args.num_epochs, batch_idx+1, num_iter, Lx.item()))
            sys.stdout.flush()
def warmup(epoch,net,optimizer,dataloader,device,whichnet):
CEloss = nn.CrossEntropyLoss()
start1 = time.time()
net.train()
num_iter = (len(dataloader.dataset)//dataloader.batch_size)+1
    # amp initialized inside the spawned worker (see note in train())
    net, optimizer = amp.initialize(net, optimizer, opt_level=args.opt_level)
for batch_idx, (inputs, labels, path) in enumerate(dataloader):
        if batch_idx < 102:   # note: processes only the first 102 batches (looks like a debugging cap left in)
inputs, labels = inputs.to(device), labels.to(device,non_blocking=True)
optimizer.zero_grad()
outputs = net(inputs)
loss = CEloss(outputs, labels)
#penalty = conf_penalty(outputs)
L = loss #+ penalty
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
# L.backward()
optimizer.step()
            if batch_idx % 100 == 0:
                end1 = time.time()
                sys.stdout.write('\n')
                sys.stdout.write('%s |%s Epoch [%3d/%3d] Iter[%4d/%4d] CE-loss: %.4f  time: %.4f'
                        %(args.id, whichnet, epoch, args.num_epochs, batch_idx+1, num_iter, loss.item(), end1-start1))
                sys.stdout.flush()
def test(epoch,net1,net2,test_loader,device,queue):
acc_meter = torchnet.meter.ClassErrorMeter(topk=[1,5], accuracy=True)
acc_meter.reset()
net1.eval()
net2.eval()
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(test_loader):
inputs, targets = inputs.to(device), targets.to(device,non_blocking=True)
outputs1 = net1(inputs)
outputs2 = net2(inputs)
outputs = outputs1+outputs2
_, predicted = torch.max(outputs, 1)
acc_meter.add(outputs,targets)
accs = acc_meter.value()
queue.put(accs)
def eval_train(eval_loader,model,device,whichnet,queue):
CE = nn.CrossEntropyLoss(reduction='none')
model.eval()
num_iter = (len(eval_loader.dataset)//eval_loader.batch_size)+1
losses = torch.zeros(len(eval_loader.dataset))
with torch.no_grad():
for batch_idx, (inputs, targets, index) in enumerate(eval_loader):
inputs, targets = inputs.to(device), targets.to(device,non_blocking=True)
outputs = model(inputs)
loss = CE(outputs, targets)
for b in range(inputs.size(0)):
losses[index[b]]=loss[b]
            if batch_idx % 100 == 0:
                sys.stdout.write('\n')
                sys.stdout.write('|%s Evaluating loss Iter[%3d/%3d] ' %(whichnet,batch_idx,num_iter))
                sys.stdout.flush()
losses = (losses-losses.min())/(losses.max()-losses.min())
# fit a two-component GMM to the loss
input_loss = losses.reshape(-1,1)
gmm = GaussianMixture(n_components=2,max_iter=10,tol=1e-2,reg_covar=1e-3)
gmm.fit(input_loss)
prob = gmm.predict_proba(input_loss)
prob = prob[:,gmm.means_.argmin()]
queue.put(prob)
def linear_rampup(current, warm_up, rampup_length=16):
current = np.clip((current-warm_up) / rampup_length, 0.0, 1.0)
return args.lambda_u*float(current)
class SemiLoss(object):
def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up):
probs_u = torch.softmax(outputs_u, dim=1)
Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
Lu = torch.mean((probs_u - targets_u)**2)
return Lx, Lu, linear_rampup(epoch,warm_up)
class NegEntropy(object):
def __call__(self,outputs):
probs = torch.softmax(outputs, dim=1)
return torch.mean(torch.sum(probs.log()*probs, dim=1))
def create_model(device):
#model = InceptionResNetV2(num_classes=args.num_class)
model = swin_base_patch4_window7_224_in22k()
model = model.to(device)
return model
if __name__ == "__main__":
mp.set_start_method('spawn')
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
stats_log=open('./checkpoint/%s'%(args.id)+'_stats.txt','w')
test_log=open('./checkpoint/%s'%(args.id)+'_acc.txt','w')
warm_up=1
loader = dataloader.webvision_dataloader(batch_size=args.batch_size,num_class = args.num_class,num_workers=8,root_dir=args.data_path,log=stats_log)
print('| Building net')
net1 = create_model(cuda1)
net2 = create_model(cuda2)
net1_clone = create_model(cuda2)
net2_clone = create_model(cuda1)
cudnn.benchmark = True
optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
# net1, optimizer1 = amp.initialize(net1, optimizer1,
# opt_level=args.opt_level
# )
# net2, optimizer2 = amp.initialize(net2, optimizer2,
# opt_level=args.opt_level
# )
# net1 = DDP(net1)
# net2 = DDP(net2)
#conf_penalty = NegEntropy()
web_valloader = loader.run('test')
# imagenet_valloader = loader.run('imagenet')
best_acc = [0,0]
for epoch in range(args.num_epochs+1):
time_start=time.time()
lr=args.lr
if epoch >= 50:
lr /= 10
for param_group in optimizer1.param_groups:
param_group['lr'] = lr
for param_group in optimizer2.param_groups:
param_group['lr'] = lr
if epoch<warm_up:
warmup_trainloader1 = loader.run('warmup')
warmup_trainloader2 = loader.run('warmup')
print("111")
p1 = mp.Process(target=warmup, args=(epoch,net1,optimizer1,warmup_trainloader1,cuda1,'net1'))
p2 = mp.Process(target=warmup, args=(epoch,net2,optimizer2,warmup_trainloader2,cuda2,'net2'))
print("222")
            p1.start()
            p2.start()
            # join before the clone state_dict copies below; without this the
            # warmup processes would still be running when the weights are read
            p1.join()
            p2.join()
else:
pred1 = (prob1 > args.p_threshold)
pred2 = (prob2 > args.p_threshold)
labeled_trainloader1, unlabeled_trainloader1 = loader.run('train',pred2,prob2) # co-divide
labeled_trainloader2, unlabeled_trainloader2 = loader.run('train',pred1,prob1) # co-divide
p1 = mp.Process(target=train, args=(epoch,net1,net2_clone,optimizer1,labeled_trainloader1, unlabeled_trainloader1,cuda1,'net1'))
p2 = mp.Process(target=train, args=(epoch,net2,net1_clone,optimizer2,labeled_trainloader2, unlabeled_trainloader2,cuda2,'net2'))
p1.start()
p2.start()
p1.join()
p2.join()
print("444")
net1_clone.load_state_dict(net1.state_dict())
net2_clone.load_state_dict(net2.state_dict())
print("3")
q1 = mp.Queue()
q2 = mp.Queue()
p1 = mp.Process(target=test, args=(epoch,net1,net2_clone,web_valloader,cuda1,q1))
p2 = mp.Process(target=test, args=(epoch,net1_clone,net2,web_valloader,cuda2,q2))
# p2 = mp.Process(target=test, args=(epoch,net1_clone,net2,imagenet_valloader,cuda2,q2))
print("4")
p1.start()
p2.start()
web_acc = q1.get()
print("5")
# imagenet_acc = q2.get()
k = 0
if web_acc[k] > best_acc[k]:
best_acc[k] = web_acc[k]
print('| Saving Best Net%d ...'%k)
save_point = './checkpoint/%s_net%d.pth.tar'%(args.id,k)
torch.save(net1.state_dict(), save_point)
k = 1
if web_acc[k] > best_acc[k]:
best_acc[k] = web_acc[k]
print('| Saving Best Net%d ...'%k)
save_point = './checkpoint/%s_net%d.pth.tar'%(args.id,k)
torch.save(net2.state_dict(), save_point)
        p1.join()
        # p2.join()  # q2 is never drained here; joining a process whose queue
        # still holds data can deadlock (see the queue note at the end)
time_end=time.time()
print("
| Test Epoch #%d WebVision Acc: %.2f%% (%.2f%%)
"%(epoch,web_acc[0],web_acc[1]))
test_log.write('Epoch:%d WebVision Acc: %.2f%% (%.2f%%) %.2f
'%(epoch,web_acc[0],web_acc[1],time_end-time_start))
test_log.flush()
eval_loader1 = loader.run('eval_train')
eval_loader2 = loader.run('eval_train')
q1 = mp.Queue()
q2 = mp.Queue()
p1 = mp.Process(target=eval_train, args=(eval_loader1,net1,cuda1,'net1',q1))
p2 = mp.Process(target=eval_train, args=(eval_loader2,net2,cuda2,'net2',q2))
print("6")
p1.start()
p2.start()
prob1 = q1.get()
prob2 = q2.get()
p1.join()
p2.join()
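Aside: the control flow above is a plain spawn/queue pattern: each network runs in its own process on its own GPU, and results come back through an mp.Queue. A minimal CPU-only sketch of that pattern (names here are illustrative, not from the script):

import torch
import torch.multiprocessing as mp

def worker(tag, queue):
    # stand-in for train()/eval_train(): do some work, send the result back
    queue.put((tag, torch.arange(3).sum().item()))

if __name__ == "__main__":
    mp.set_start_method('spawn', force=True)
    q1, q2 = mp.Queue(), mp.Queue()
    p1 = mp.Process(target=worker, args=('net1', q1))
    p2 = mp.Process(target=worker, args=('net2', q2))
    p1.start(); p2.start()
    print(q1.get(), q2.get())  # drain the queues before joining, or join can hang
    p1.join(); p2.join()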