zoukankan      html  css  js  c++  java
  • 手敲pytorch版unet

    import logging
    import argparse
    import sys
    import os
    from glob import glob
    from PIL import Image
    from tqdm import tqdm
    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import Dataset,DataLoader,random_split
    from torch.utils.tensorboard import SummaryWriter
    from torch import optim
    from torch.autograd import Function
    import numpy as np
    
    def get_args(argv=None):
        """Parse the command-line options for UNet training.

        Args:
            argv: optional list of argument strings to parse instead of
                ``sys.argv[1:]`` (handy for tests and programmatic use).

        Returns:
            argparse.Namespace with attributes ``epochs``, ``batchsize``, ``lr``,
            ``load``, ``scale`` and ``val``.
        """
        parser = argparse.ArgumentParser(description='the unet training args.',
                                         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,
                            help='nums of epochs', dest='epochs')
        parser.add_argument('-b', '--batch-size', metavar='B', type=int, default=2,
                            help='batch size', dest='batchsize')
        # nargs='?' lets "-l" appear without a value. `const` is the value used in
        # that case; without it argparse stores None and the optimizer fails later.
        parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?',
                            default=0.0001, const=0.0001,
                            help='Learning rate', dest='lr')
        parser.add_argument('-f', '--load', dest='load', type=str, default=False,
                            help='Load model from a .pth file')
        parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5,
                            help='Downscaling factor of the images')
        parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
                            help='Percent of the data that is used as validation (0-100)')
        return parser.parse_args(argv)
    
    # network definition.
    class DoubleConv(nn.Module):
        """Two stacked stages of 3x3 convolution -> BatchNorm -> ReLU.

        Both convolutions use padding=1. ``mid_channels`` is the output width of
        the first convolution; when it is falsy, ``out_channels`` is used instead.
        """

        def __init__(self, in_channels, out_channels, mid_channels=None):
            super().__init__()
            mid_channels = mid_channels or out_channels
            stages = []
            for c_in, c_out in ((in_channels, mid_channels), (mid_channels, out_channels)):
                stages += [
                    nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                    nn.BatchNorm2d(c_out),
                    nn.ReLU(inplace=True),
                ]
            # The attribute name and layer order define the state_dict keys
            # (double_conv.0.weight, ...); keep them stable so checkpoints load.
            self.double_conv = nn.Sequential(*stages)

        def forward(self, x):
            """Run ``x`` through both conv stages."""
            out = self.double_conv(x)
            return out
    
    class Down(nn.Module):
        """Encoder stage: 2x2 max-pooling followed by a DoubleConv."""

        def __init__(self, in_channels, out_channels):
            super().__init__()
            pool = nn.MaxPool2d(2)
            convs = DoubleConv(in_channels, out_channels)
            # Kept as a single Sequential named maxpool_conv so state_dict keys match.
            self.maxpool_conv = nn.Sequential(pool, convs)

        def forward(self, x):
            """Pool ``x`` by a factor of 2, then apply the DoubleConv."""
            out = self.maxpool_conv(x)
            return out
    
    class Up(nn.Module):
        """Decoder stage: upsample ``x1``, pad it to ``x2``'s size, concatenate, DoubleConv.

        With ``bilinear=True`` the upsampling has no parameters, and the DoubleConv
        uses ``in_channels // 2`` mid channels to reduce the width. Otherwise a
        stride-2 transposed convolution maps ``in_channels`` to ``in_channels // 2``
        before concatenation.
        """

        def __init__(self, in_channels, out_channels, bilinear=True):
            super().__init__()
            half = in_channels // 2
            if not bilinear:
                # kernel 2 / stride 2 transposed conv doubles the spatial size.
                self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
                self.conv = DoubleConv(in_channels, out_channels)
                return
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

        def forward(self, x1, x2):
            """Upsample ``x1``, align it with skip tensor ``x2`` and fuse them.

            Dims 2 and 3 are treated as height and width; concatenation is on dim 1.
            """
            x1 = self.up(x1)
            dh = x2.size(2) - x1.size(2)
            dw = x2.size(3) - x1.size(3)
            # F.pad pads the last dim first: (left, right, top, bottom). Any odd
            # remainder goes to the right/bottom side.
            padding = [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2]
            x1 = F.pad(x1, padding)
            merged = torch.cat([x2, x1], dim=1)
            return self.conv(merged)
    
    class OutConv(nn.Module):
        """1x1 convolution that projects features to ``out_channels`` logits."""

        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, 1)

        def forward(self, x):
            """Apply the 1x1 projection to ``x``."""
            projected = self.conv(x)
            return projected
    
    class UNet(nn.Module):
        """UNet encoder/decoder network.

        Args:
            n_channels: number of input channels.
            n_classes: number of output (logit) channels.
            bilinear: if True, decoder stages upsample bilinearly; otherwise they
                use transposed convolutions.
        """

        def __init__(self, n_channels, n_classes, bilinear=True):
            super().__init__()
            self.n_channels = n_channels
            self.n_classes = n_classes
            self.bilinear = bilinear

            # Bilinear upsampling cannot reduce channels, so the bottleneck and
            # decoder outputs are halved to keep concatenated widths consistent.
            factor = 2 if bilinear else 1

            # Encoder. Attribute names and registration order are the state_dict
            # prefixes used by saved checkpoints; do not rename or reorder.
            self.inc = DoubleConv(n_channels, 64)
            self.down1 = Down(64, 128)
            self.down2 = Down(128, 256)
            self.down3 = Down(256, 512)
            self.down4 = Down(512, 1024 // factor)
            # Decoder.
            self.up1 = Up(1024, 512 // factor, bilinear)
            self.up2 = Up(512, 256 // factor, bilinear)
            self.up3 = Up(256, 128 // factor, bilinear)
            self.up4 = Up(128, 64, bilinear)
            self.outc = OutConv(64, n_classes)

        def forward(self, x):
            """Return the raw logits (``n_classes`` channels) for input ``x``."""
            skips = []
            for stage in (self.inc, self.down1, self.down2, self.down3):
                x = stage(x)
                skips.append(x)
            x = self.down4(x)
            # Decoder consumes skip connections deepest-first.
            for stage in (self.up1, self.up2, self.up3, self.up4):
                x = stage(x, skips.pop())
            return self.outc(x)
    
    # Data locations (relative paths). File names are appended with plain string
    # concatenation, so each directory must keep its trailing slash.
    dir_img, dir_mask, dir_checkpoint = 'data/imgs/', 'data/masks/', 'my_checkpoints/'
    
    class BasicDataset(Dataset):
        """Paired image/mask dataset read from two directories.

        Every non-hidden file in ``imgs_dir`` defines a sample ID (its file name
        without extension). For ID ``x``:
            - the image is the single file matching ``imgs_dir + x + '.*'``;
            - the mask is the single file matching ``masks_dir + x + mask_suffix + '.*'``.
        Directories are joined by string concatenation, so they must end with a
        path separator.

        Args:
            imgs_dir: directory holding the input images.
            masks_dir: directory holding the masks.
            scale: resize factor in (0, 1] applied to both image and mask.
            mask_suffix: suffix appended to an ID to form its mask file name.
        """

        def __init__(self, imgs_dir, masks_dir, scale=1, mask_suffix='_mask'):
            self.imgs_dir = imgs_dir
            self.masks_dir = masks_dir
            self.scale = scale
            self.masks_suffix = mask_suffix
            assert 0 < scale <= 1, 'Scale must be between 0 and 1.'
            # Skip hidden files (.gitkeep, .DS_Store, ...). They are not images and
            # would fail the "exactly one image" check in __getitem__.
            self.ids = [os.path.splitext(file)[0] for file in os.listdir(imgs_dir)
                        if not file.startswith('.')]
            logging.info(f'Creating dataset with {len(self.ids)} examples')

        def __len__(self):
            return len(self.ids)

        def __getitem__(self, i):
            """Return ``{'image': FloatTensor, 'mask': FloatTensor}``, each shaped (C, H, W)."""
            # Module-level `glob` is the function, so import escape separately.
            from glob import escape as glob_escape
            idx = self.ids[i]
            # Escape glob metacharacters ([, *, ?) so IDs and directories match literally.
            mask_file = glob(glob_escape(self.masks_dir + idx + self.masks_suffix) + '.*')
            img_file = glob(glob_escape(self.imgs_dir + idx) + '.*')
            assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
            assert len(img_file) == 1, f'Either no image or multiple images found for the ID {idx}: {img_file}'
            mask = Image.open(mask_file[0])
            img = Image.open(img_file[0])
            assert img.size == mask.size, f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'
            img = self.preprocess(img, self.scale)
            # Nearest-neighbour resizing keeps mask labels free of interpolated values.
            mask = self.preprocess(mask, self.scale, is_mask=True)

            return {
                'image': torch.from_numpy(img).type(torch.FloatTensor),
                'mask': torch.from_numpy(mask).type(torch.FloatTensor)
            }

        @classmethod
        def preprocess(cls, pil_img, scale, is_mask=False):
            """Resize a PIL image by ``scale`` and convert it to a CHW numpy array.

            Args:
                pil_img: PIL image to convert.
                scale: resize factor; the new width/height are truncated to int.
                is_mask: if True, resize with nearest-neighbour sampling so label
                    values are not blended.

            Returns:
                numpy array shaped (C, H, W). Values are divided by 255 when the
                maximum exceeds 1; otherwise they are returned unchanged.

            Raises:
                AssertionError: if the scaled width or height is zero.
            """
            w, h = pil_img.size
            new_w, new_h = int(scale * w), int(scale * h)
            assert new_w > 0 and new_h > 0, 'scale is too small'
            if is_mask:
                # Pillow's default resampling filter (bicubic for most modes)
                # would create fractional labels along object borders.
                pil_img = pil_img.resize((new_w, new_h), resample=Image.NEAREST)
            else:
                pil_img = pil_img.resize((new_w, new_h))
            img_nd = np.array(pil_img)
            if img_nd.ndim == 2:
                # Single-channel image: add a trailing channel axis.
                img_nd = np.expand_dims(img_nd, axis=2)
            # HWC -> CHW
            img_trans = img_nd.transpose((2, 0, 1))
            if img_trans.max() > 1:
                img_trans = img_trans / 255.

            return img_trans
    
    class DiceCoeff(Function):
        """Dice coefficient of a single example, as a custom autograd Function.

        Use ``DiceCoeff.apply(input, target)``. Only ``input`` receives a gradient.
        """

        # Smoothing term added to numerator and denominator.
        EPS = 0.0001

        @staticmethod
        def forward(ctx, input, target):
            """Return ``(2 * <input, target> + eps) / (sum(input) + sum(target) + eps)``."""
            ctx.save_for_backward(input, target)
            eps = DiceCoeff.EPS
            # reshape (not view) also works for non-contiguous tensors.
            inter = torch.dot(input.reshape(-1), target.reshape(-1)).float()
            union = (torch.sum(input) + torch.sum(target) + eps).float()
            return (2 * inter + eps) / union

        @staticmethod
        def backward(ctx, grad_output):
            """Return gradients for (input, target); target's is always None.

            backward() gets one gradient per forward() output and returns one
            gradient per forward() input. None marks inputs that need no gradient.
            """
            input, target = ctx.saved_tensors
            grad_input = grad_target = None
            if ctx.needs_input_grad[0]:
                eps = DiceCoeff.EPS
                inter = torch.dot(input.reshape(-1), target.reshape(-1))
                union = torch.sum(input) + torch.sum(target) + eps
                # d/dx_i [(2I + eps) / U] = (2 * t_i * U - (2I + eps)) / U^2
                grad_input = grad_output * (2 * target * union - (2 * inter + eps)) / (union * union)
            return grad_input, grad_target


    def dice_coeff(input, target):
        """Mean per-example Dice coefficient over the first (batch) dimension.

        Pairs are formed by zipping ``input`` and ``target`` along dim 0.

        Returns:
            A 1-element float32 tensor on ``input``'s device.

        Raises:
            ValueError: if there are no (input, target) pairs.
        """
        pairs = list(zip(input, target))
        if not pairs:
            raise ValueError('dice_coeff() needs at least one (input, target) pair')
        # Accumulate on input's device (CUDA, CPU or any other backend).
        s = torch.zeros(1, dtype=torch.float32, device=input.device)
        for pred, truth in pairs:
            s = s + DiceCoeff.apply(pred, truth)
        return s / len(pairs)
    
    def eval(net, loader, device):
        """Score ``net`` on every batch of ``loader`` and return the mean per-batch score.

        If ``net.n_classes > 1``, the score is cross-entropy (lower is better).
        Otherwise it is the Dice coefficient of the sigmoid output thresholded at
        0.5 (higher is better). The network's previous train/eval mode is restored
        afterwards, even if an exception occurs.

        Args:
            net: model exposing an ``n_classes`` attribute.
            loader: iterable of dicts with 'image' and 'mask' tensors.
            device: device to run evaluation on.

        Raises:
            ValueError: if ``loader`` yields no batches (e.g. drop_last with a
                validation set smaller than one batch).
        """
        n_val = len(loader)  # number of batches
        if n_val == 0:
            raise ValueError('validation loader is empty; nothing to evaluate')
        mask_type = torch.float32 if net.n_classes == 1 else torch.long
        was_training = net.training
        net.eval()
        tot = 0
        try:
            with torch.no_grad(), \
                    tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
                for batch in loader:
                    imgs = batch['image'].to(device=device, dtype=torch.float32)
                    true_masks = batch['mask'].to(device=device, dtype=mask_type)
                    mask_pred = net(imgs)

                    if net.n_classes > 1:
                        tot += F.cross_entropy(mask_pred, true_masks).item()
                    else:
                        pred = (torch.sigmoid(mask_pred) > 0.5).float()
                        tot += dice_coeff(pred, true_masks).item()
                    pbar.update()
        finally:
            net.train(was_training)
        return tot / n_val
    
    def train(net,
              device,
              epochs=5,
              batch_size=1,
              lr=0.001,
              val_percent=0.1,
              save_cp=True,
              img_scale=0.5):
        """Train ``net`` on the images in ``dir_img`` and masks in ``dir_mask``.

        Workflow:
            1. Split the dataset randomly, putting ``val_percent`` of the samples
               in a validation subset.
            2. Train with RMSprop, clipping gradient values to 0.1.
            3. About ten times per epoch:
               - log weight/gradient histograms and sample images to TensorBoard;
               - if the validation loader is non-empty, evaluate the model and
                 step a ReduceLROnPlateau scheduler on the score.

        Args:
            net: model exposing ``n_channels`` and ``n_classes`` attributes.
            device: torch.device to train on.
            epochs: number of passes over the training subset.
            batch_size: samples per batch for both loaders.
            lr: initial learning rate.
            val_percent: fraction (0-1) of the dataset used for validation.
            save_cp: if True, save the state_dict into ``dir_checkpoint`` after each epoch.
            img_scale: resize factor passed to BasicDataset.
        """
        # Data loading.
        dataset = BasicDataset(dir_img, dir_mask, img_scale)

        # Random train/validation split. Names avoid shadowing this function.
        n_val = int(len(dataset) * val_percent)
        n_train = len(dataset) - n_val
        train_set, val_set = random_split(dataset, [n_train, n_val])
        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
        val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=True)
        # With drop_last=True, a validation subset smaller than one batch yields
        # no batches, and there would be nothing to average over.
        run_validation = len(val_loader) > 0
        if not run_validation:
            logging.warning(f'Validation set ({n_val} samples) is smaller than one batch '
                            f'({batch_size}); validation and LR scheduling are skipped.')

        # TensorBoard writer; `comment` is appended as a suffix to the log_dir name.
        writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
        global_step = 0
        optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
        # The validation score is cross-entropy (minimize) for multi-class and Dice (maximize) otherwise.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
        criterion = nn.CrossEntropyLoss() if net.n_classes > 1 else nn.BCEWithLogitsLoss()
        # NOTE(review): for n_classes > 1, CrossEntropyLoss expects (N, H, W)
        # class indices, but BasicDataset appears to yield (N, C, H, W) masks
        # scaled to [0, 1]. Confirm before using the multi-class path.
        mask_type = torch.float32 if net.n_classes == 1 else torch.long
        # Log/validate about ten times per epoch. max(1, ...) prevents a modulo by
        # zero when the training set has fewer than 10 * batch_size samples.
        log_interval = max(1, n_train // (10 * batch_size))

        logging.info(f'''Starting training:
            Epochs:          {epochs}
            Batch size:      {batch_size}
            Learning rate:   {lr}
            Training size:   {n_train}
            Validation size: {n_val}
            Checkpoints:     {save_cp}
            Device:          {device.type}
            Images scaling:  {img_scale}
        ''')

        try:
            for epoch in range(epochs):
                net.train()

                with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img', leave=False) as pbar:
                    for batch in train_loader:
                        imgs = batch['image']
                        true_masks = batch['mask']
                        assert imgs.shape[1] == net.n_channels, (
                            f'Network has been defined with {net.n_channels} input channels, '
                            f'but loaded images have {imgs.shape[1]} channels. Please check that '
                            'the images are loaded correctly.')

                        imgs = imgs.to(device, dtype=torch.float32)
                        true_masks = true_masks.to(device=device, dtype=mask_type)

                        masks_pred = net(imgs)
                        loss = criterion(masks_pred, true_masks)
                        batch_loss = loss.item()
                        writer.add_scalar('Loss/train', batch_loss, global_step)

                        # Show the current batch loss in the progress bar.
                        pbar.set_postfix(**{'loss (batch)': batch_loss})

                        optimizer.zero_grad()
                        loss.backward()
                        nn.utils.clip_grad_value_(net.parameters(), 0.1)  # gradient clipping
                        optimizer.step()

                        pbar.update(imgs.shape[0])  # advance by the number of images in this batch
                        global_step += 1
                        if global_step % log_interval == 0:
                            for tag, value in net.named_parameters():
                                tag = tag.replace('.', '/')
                                writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
                                if value.grad is not None:
                                    writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)

                            if run_validation:
                                # Full pass over the validation set.
                                val_score = eval(net, val_loader, device)
                                scheduler.step(val_score)
                                writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

                                if net.n_classes > 1:
                                    logging.info('Validation cross entropy: {}'.format(val_score))
                                    writer.add_scalar('Loss/test', val_score, global_step)
                                else:
                                    logging.info('Validation Dice Coeff: {}'.format(val_score))
                                    writer.add_scalar('Dice/test', val_score, global_step)

                            writer.add_images('images', imgs, global_step)
                            if net.n_classes == 1:
                                writer.add_images('masks/true', true_masks, global_step)
                                writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)

                if save_cp:
                    if not os.path.isdir(dir_checkpoint):
                        os.makedirs(dir_checkpoint, exist_ok=True)
                        logging.info('Created checkpoint directory')
                    torch.save(net.state_dict(), dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
                    logging.info(f'Checkpoint {epoch + 1} saved !')
        finally:
            # Flush TensorBoard events even on error or KeyboardInterrupt.
            writer.close()
    
    if __name__ == '__main__':
        logging.basicConfig(level=logging.INFO, format='%(levelname)s:%(message)s')
        args = get_args()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logging.info(f'using device {device}')

        # Define the UNet.
        net = UNet(n_channels=3, n_classes=1, bilinear=True)
        net.to(device=device)

        logging.info(f'Network:\n'
                     f'\t{net.n_channels} input channels\n'
                     f'\t{net.n_classes} output channels (classes)\n'
                     f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

        # Fine-tune from an existing checkpoint.
        if args.load:
            # NOTE(review): torch.load unpickles the file; only load checkpoints
            # from trusted sources (consider weights_only=True on newer PyTorch).
            net.load_state_dict(
                torch.load(args.load, map_location=device)
            )
            logging.info(f'Model loaded from {args.load}')

        try:
            train(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  device=device,
                  img_scale=args.scale,
                  val_percent=args.val / 100)
        except KeyboardInterrupt:
            # Keep the partially trained weights when the run is interrupted.
            torch.save(net.state_dict(), 'INTERRUPTED.pth')
            logging.info('Saved interrupt')
            try:
                sys.exit(0)
            except SystemExit:
                os._exit(0)
    

      

  • 相关阅读:
    golang 二进制转十进制实现方式
    mysql select column default value if is null
    MySQL 忘记root密码解决方法,基于Ubuntu 14.10
    ecshop 修改记录20150710
    mac os x 10.9.3 升级到10.10.4 记录
    View requires API level 14 (current min is 8): <GridLayout>
    android开发过程中遇到的坑
    如何为网站设置站点图标
    seq 显示00 01的格式
    lvs realserver 配置VIP
  • 原文地址:https://www.cnblogs.com/liutianrui1/p/13915324.html
Copyright © 2011-2022 走看看