import logging
import argparse
import sys
import os
from glob import glob

from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from torch import optim
from torch.autograd import Function
import numpy as np


def get_args(argv=None):
    """Parse the command-line options of the U-Net training script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, so existing ``get_args()`` calls
            behave as before. Passing a list makes the parser usable from
            other code.

    Returns:
        argparse.Namespace with the attributes ``epochs``, ``batchsize``,
        ``lr``, ``load``, ``scale`` and ``val``.

    Exits via ``parser.error`` (SystemExit) on malformed options or when
    ``--validation`` is outside 0-100.
    """
    parser = argparse.ArgumentParser(
        description='the unet training args.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,
                        help='nums of epochs', dest='epochs')
    parser.add_argument('-b', '--batch-size', metavar='B', type=int, default=2,
                        help='batch size', dest='batchsize')
    # No nargs='?' here: with it, a bare "-l" was accepted and silently
    # produced lr=None instead of reporting the missing value.
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float,
                        default=0.0001, help='Learning rate', dest='lr')
    # The default stays False (not None) so the printed help and the
    # truthiness check done by the caller remain unchanged.
    parser.add_argument('-f', '--load', dest='load', type=str, default=False,
                        help='Load model from a .pth file')
    parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5,
                        help='Downscaling factor of the images')
    parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')

    args = parser.parse_args(argv)
    if not 0 <= args.val <= 100:
        parser.error(f'--validation must be within 0-100, got {args.val}')
    return args


# Network definition.
class DoubleConv(nn.Module):
    """(convolution -> [BN] -> ReLU) * 2

    Both convolutions use a 3x3 kernel with padding 1.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two convolutions; falls back to
            ``out_channels`` when omitted (or falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling step: 2x2 max-pooling followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling step: upsample, merge with the skip connection, DoubleConv.

    Args:
        in_channels: channels of the concatenated (skip + upsampled) tensor.
        out_channels: channels produced by the block.
        bilinear: use ``nn.Upsample`` (bilinear) when true, otherwise a
            learned ``nn.ConvTranspose2d``.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            # Upsampling keeps the channel count, so the DoubleConv narrows
            # through in_channels // 2 on its way to out_channels.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            # stride=2 doubles the spatial size; the transposed convolution
            # itself halves the channel count.
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the size of ``x2`` and merge both.

        ``x1`` comes from the deeper layer, ``x2`` is the skip connection.
        Dimensions 2 and 3 are treated as height and width.
        """
        x1 = self.up(x1)
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        # F.pad takes (left, right, top, bottom) for the last two dimensions;
        # any odd remainder goes to the right / bottom side.
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping features to ``out_channels`` maps."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    """UNet network architecture.

    Args:
        n_channels: channels of the input images.
        n_classes: channels of the returned logits.
        bilinear: passed to every ``Up`` block; it also halves the width of
            the deepest layers (``factor`` below).
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return raw logits (no sigmoid / softmax is applied here)."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


# Data loading: default locations of images, masks and saved checkpoints.
dir_img = 'data/imgs/'
dir_mask = 'data/masks/'
dir_checkpoint = 'my_checkpoints/'


class BasicDataset(Dataset):
    """Image / mask pairs matched by file name.

    An image ``<id>.<ext>`` in ``imgs_dir`` is paired with the single file
    matching ``<id><mask_suffix>.*`` in ``masks_dir``.

    Args:
        imgs_dir: directory holding the input images.
        masks_dir: directory holding the masks.
        scale: resize factor applied to both image and mask, in (0, 1].
        mask_suffix: text appended to the image id to form the mask name.
    """

    def __init__(self, imgs_dir, masks_dir, scale=1, mask_suffix='_mask'):
        assert 0 < scale <= 1, 'Scale must be between 0 and 1.'
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.scale = scale
        self.masks_suffix = mask_suffix

        # Hidden entries (.DS_Store, .gitkeep, ...) are not samples; keeping
        # them made __getitem__ fail later with "no mask found".
        self.ids = [os.path.splitext(file)[0]
                    for file in os.listdir(imgs_dir)
                    if not file.startswith('.')]
        logging.info(f'Creating dataset with {len(self.ids)} examples')

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        """Return ``{'image': FloatTensor, 'mask': FloatTensor}`` for sample ``i``."""
        idx = self.ids[i]
        # os.path.join works whether or not the directory was given with a
        # trailing separator (plain string concatenation required one).
        mask_file = glob(os.path.join(self.masks_dir, idx + self.masks_suffix + '.*'))
        img_file = glob(os.path.join(self.imgs_dir, idx + '.*'))

        assert len(mask_file) == 1, \
            f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
        assert len(img_file) == 1, \
            f'Either no image or multiple images found for the ID {idx}: {img_file}'
        mask = Image.open(mask_file[0])
        img = Image.open(img_file[0])

        assert img.size == mask.size, \
            f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'

        img = self.preprocess(img, self.scale)
        mask = self.preprocess(mask, self.scale)

        return {
            'image': torch.from_numpy(img).type(torch.FloatTensor),
            'mask': torch.from_numpy(mask).type(torch.FloatTensor)
        }

    @classmethod
    def preprocess(cls, pil_img, scale):
        """Resize a PIL image by ``scale`` and return it as a CHW array.

        Arrays whose maximum exceeds 1 are divided by 255.
        """
        w, h = pil_img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, 'scale is too small'
        pil_img = pil_img.resize((newW, newH))

        img_nd = np.array(pil_img)
        if len(img_nd.shape) == 2:
            # Single-channel image: add an explicit channel axis.
            img_nd = np.expand_dims(img_nd, axis=2)

        # HWC to CHW
        img_trans = img_nd.transpose((2, 0, 1))
        if img_trans.max() > 1:
            img_trans = img_trans / 255.
        return img_trans


class DiceCoeff(Function):
    """Dice coeff for individual examples.

    NOTE(review): this keeps the legacy instance-method ``forward`` /
    ``backward`` layout because callers in this file invoke
    ``DiceCoeff().forward(a, b)`` directly. Current PyTorch expects static
    methods used through ``.apply``; confirm before relying on ``backward``.
    """

    def forward(self, input, target):
        """Return ``(2 * <input, target> + eps) / (sum(input) + sum(target) + eps)``."""
        self.save_for_backward(input, target)
        eps = 0.0001
        # Dot product of the flattened tensors, i.e. the soft intersection.
        self.inter = torch.dot(input.view(-1), target.view(-1))
        # Scalar: total mass of both tensors (eps avoids division by zero).
        self.union = torch.sum(input) + torch.sum(target) + eps

        t = (2 * self.inter.float() + eps) / self.union.float()
        return t

    # This function has only a single output, so it gets only one gradient.
    # backward() receives one gradient per output of forward() and returns
    # one gradient per input of forward(); None may be returned for inputs
    # that do not need a gradient or are not differentiable.
    def backward(self, grad_output):
        # Pairs with save_for_backward() in forward(). saved_tensors replaces
        # the deprecated saved_variables attribute.
        input, target = self.saved_tensors
        grad_input = grad_target = None

        # needs_input_grad holds one boolean per forward() input.
        if self.needs_input_grad[0]:
            # Hand-written gradient with respect to input.
            grad_input = grad_output * 2 * (target * self.union - self.inter) \
                         / (self.union * self.union)
        if self.needs_input_grad[1]:
            grad_target = None  # labels receive no gradient

        # One entry per forward() input.
        return grad_input, grad_target
def dice_coeff(input, target):
    """Dice coeff for batches.

    Averages the per-sample Dice coefficient over the first dimension of
    ``input`` / ``target``.

    Returns:
        A 1-element float tensor on ``input``'s device. An empty batch
        yields zero instead of failing.
    """
    eps = 0.0001
    total = torch.zeros(1, device=input.device)
    count = 0
    # Same formula as DiceCoeff.forward, computed inline: instantiating an
    # autograd Function only to call forward() is deprecated in PyTorch.
    for pred, truth in zip(input, target):
        inter = torch.dot(pred.reshape(-1), truth.reshape(-1))
        union = torch.sum(pred) + torch.sum(truth) + eps
        total = total + (2 * inter.float() + eps) / union.float()
        count += 1
    if count == 0:
        return total
    return total / count


def eval(net, loader, device):
    """Score ``net`` on every batch of ``loader`` and return the mean.

    The score is cross entropy when ``net.n_classes > 1``; otherwise it is
    the Dice coefficient of the sigmoid output thresholded at 0.5. The
    network is switched to eval mode for the pass and back to train mode
    before returning.

    Note: the name shadows the ``eval`` builtin inside this module; it is
    kept because callers use it.

    Returns:
        Mean score over the batches, or 0.0 when ``loader`` has no batches.
    """
    net.eval()
    mask_type = torch.float32 if net.n_classes == 1 else torch.long
    n_val = len(loader)  # number of batches
    tot = 0

    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
        for batch in loader:
            imgs, true_masks = batch['image'], batch['mask']
            imgs = imgs.to(device=device, dtype=torch.float32)
            true_masks = true_masks.to(device=device, dtype=mask_type)

            with torch.no_grad():
                mask_pred = net(imgs)

            if net.n_classes > 1:
                tot += F.cross_entropy(mask_pred, true_masks).item()
            else:
                pred = torch.sigmoid(mask_pred)
                pred = (pred > 0.5).float()
                tot += dice_coeff(pred, true_masks).item()
            pbar.update()

    net.train()
    if n_val == 0:
        # Nothing to average over; avoid the ZeroDivisionError.
        logging.warning('Validation loader is empty; returning 0.0')
        return 0.0
    return tot / n_val


def train(net, device, epochs=5, batch_size=1, lr=0.001, val_percent=0.1, save_cp=True, img_scale=0.5):
    """Train ``net`` on the images in ``dir_img`` / ``dir_mask``.

    Args:
        net: network exposing ``n_channels`` and ``n_classes``.
        device: torch device the batches are moved to.
        epochs: number of passes over the training split.
        batch_size: batch size of both data loaders.
        lr: learning rate handed to RMSprop.
        val_percent: fraction (0-1) of the dataset held out for validation.
        save_cp: save a checkpoint into ``dir_checkpoint`` after each epoch.
        img_scale: resize factor passed to ``BasicDataset``.
    """
    # Data loading.
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    # Random train / validation split.
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=8, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False,
                            num_workers=8, pin_memory=True, drop_last=True)
    # drop_last=True can leave the validation loader without any batch.
    has_val = len(val_loader) > 0

    # TensorBoard; ``comment`` is appended to the log directory name.
    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    # Cross entropy should go down, the Dice coefficient should go up.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
    criterion = nn.CrossEntropyLoss() if net.n_classes > 1 else nn.BCEWithLogitsLoss()
    mask_type = torch.float32 if net.n_classes == 1 else torch.long

    logging.info(f'''Starting training: Epochs: {epochs} Batch size: {batch_size} Learning rate: {lr} Training size: {n_train} Validation size: {n_val} Checkpoints: {save_cp} Device: {device.type} Images scaling: {img_scale} ''')

    # Validate roughly ten times per epoch. Clamp to 1 so small datasets
    # (n_train < 10 * batch_size) no longer trigger a modulo by zero.
    val_interval = max(1, n_train // (10 * batch_size))

    try:
        for epoch in range(epochs):
            net.train()

            with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img', leave=False) as pbar:
                for batch in train_loader:
                    imgs = batch['image']
                    true_masks = batch['mask']
                    assert imgs.shape[1] == net.n_channels, (
                        f'Network has been defined with {net.n_channels} input channels, '
                        f'but loaded images have {imgs.shape[1]} channels. Please check that '
                        'the images are loaded correctly.'
                    )

                    imgs = imgs.to(device, dtype=torch.float32)
                    true_masks = true_masks.to(device=device, dtype=mask_type)

                    masks_pred = net(imgs)
                    loss = criterion(masks_pred, true_masks)
                    writer.add_scalar('Loss/train', loss.item(), global_step)

                    # Show the current batch loss in the progress bar.
                    pbar.set_postfix(**{'loss (batch)': loss.item()})

                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_value_(net.parameters(), 0.1)  # gradient clipping
                    optimizer.step()

                    pbar.update(imgs.shape[0])  # advance by the images just processed
                    global_step += 1
                    if global_step % val_interval != 0:
                        continue

                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
                        # Parameters that received no gradient have grad None.
                        if value.grad is not None:
                            writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)

                    if has_val:
                        # Full pass over the validation split; its score
                        # drives the plateau scheduler.
                        val_score = eval(net, val_loader, device)
                        scheduler.step(val_score)
                        if net.n_classes > 1:
                            logging.info('Validation cross entropy: {}'.format(val_score))
                            writer.add_scalar('Loss/test', val_score, global_step)
                        else:
                            logging.info('Validation Dice Coeff: {}'.format(val_score))
                            writer.add_scalar('Dice/test', val_score, global_step)
                    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

                    writer.add_images('images', imgs, global_step)
                    if net.n_classes == 1:
                        writer.add_images('masks/true', true_masks, global_step)
                        writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)

            if save_cp:
                try:
                    os.makedirs(dir_checkpoint)
                    logging.info('Created checkpoint directory')
                except FileExistsError:
                    # Already present; any other OSError is a real problem
                    # and is allowed to propagate.
                    pass
                torch.save(net.state_dict(),
                           os.path.join(dir_checkpoint, f'CP_epoch{epoch + 1}.pth'))
                logging.info(f'Checkpoint {epoch + 1} saved !')
    finally:
        # Flush and close the event file even when training is interrupted.
        writer.close()


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(levelname)s:%(message)s')
    args = get_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'using device {device}')

    # Define the U-Net.
    net = UNet(n_channels=3, n_classes=1, bilinear=True)
    net.to(device=device)
    logging.info(f'Network: '
                 f' {net.n_channels} input channels '
                 f' {net.n_classes} output channels (classes) '
                 f' {"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    # Handy while developing the network: push a dummy batch through it.
    # x = torch.ones(1, 3, 256, 256)
    # tmp = net(x)

    # Optionally continue from / fine-tune a saved state dict.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    if args.load:
        net.load_state_dict(
            torch.load(args.load, map_location=device)
        )
        logging.info(f'Model loaded from {args.load}')

    try:
        train(net=net,
              epochs=args.epochs,
              batch_size=args.batchsize,
              lr=args.lr,
              device=device,
              img_scale=args.scale,
              val_percent=args.val / 100)
    except KeyboardInterrupt:
        # Keep the weights trained so far before exiting.
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)