  • PyTorch Segmentation Model Construction and Training [Live] 2019 County-Level Agricultural Brain AI Challenge (4): Model Construction and Network Training

    Treated as a black box, the segmentation network takes a 3x1024x1024 input and produces a 4x1024x1024 output.

    I did not use binary classification; I went with four-class classification directly.

    The segmentation network is SegNet, with no pretrained weights loaded and parameters left at their default initialization. To speed up training, the 1024 input is shrunk to a 256 size by pooling right after entering the network, and at the output layer it is upsampled 4x with bilinear interpolation, so the network effectively trains at the 256 resolution.
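
    A minimal sketch of this resolution trick in isolation (the tensors below are stand-ins; in the real network the SegNet encoder-decoder body sits between the two calls):

    import torch
    import torch.nn.functional as F

    x = torch.rand(1, 3, 1024, 1024)     # full-resolution input
    small = F.avg_pool2d(x, 4)           # 1024 -> 256 before the encoder
    # ... the SegNet body runs at 256x256 and emits 4-channel logits ...
    logits = torch.rand(1, 4, 256, 256)  # stand-in for the decoder output
    out = F.interpolate(logits, scale_factor=4, mode='bilinear', align_corners=False)
    print(out.shape)                     # torch.Size([1, 4, 1024, 1024])

    The full network definition follows: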

    import os
    import urllib
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    #import torch.utils.model_zoo as model_zoo
    from torchvision import models
    #https://raw.githubusercontent.com/delta-onera/delta_tb/master/deltatb/networks/net_segnet_bn_relu.py
    class SegNet_BN_ReLU(nn.Module):
        # SegNet encoder-decoder network with BatchNorm + ReLU
        @staticmethod
        def weight_init(m):
            # Only nn.Linear layers are re-initialized here; this network has
            # none, so in practice all weights keep PyTorch's default
            # initialization, matching the description above.
            if isinstance(m, nn.Linear):
                torch.nn.init.kaiming_normal_(m.weight.data)
        
        def __init__(self, in_channels, out_channels):
            super(SegNet_BN_ReLU, self).__init__()
    
            self.in_channels = in_channels
            self.out_channels = out_channels
    
            self.pool = nn.MaxPool2d(2, return_indices=True)
            self.unpool = nn.MaxUnpool2d(2)
            
            self.conv1_1 = nn.Conv2d(in_channels, 64, 3, padding=1)
            self.conv1_1_bn = nn.BatchNorm2d(64)
            self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
            self.conv1_2_bn = nn.BatchNorm2d(64)
            
            self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
            self.conv2_1_bn = nn.BatchNorm2d(128)
            self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
            self.conv2_2_bn = nn.BatchNorm2d(128)
            
            self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
            self.conv3_1_bn = nn.BatchNorm2d(256)
            self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
            self.conv3_2_bn = nn.BatchNorm2d(256)
            self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
            self.conv3_3_bn = nn.BatchNorm2d(256)
            
            self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
            self.conv4_1_bn = nn.BatchNorm2d(512)
            self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
            self.conv4_2_bn = nn.BatchNorm2d(512)
            self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
            self.conv4_3_bn = nn.BatchNorm2d(512)
            
            self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
            self.conv5_1_bn = nn.BatchNorm2d(512)
            self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
            self.conv5_2_bn = nn.BatchNorm2d(512)
            self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
            self.conv5_3_bn = nn.BatchNorm2d(512)
            
            self.conv5_3_D = nn.Conv2d(512, 512, 3, padding=1)
            self.conv5_3_D_bn = nn.BatchNorm2d(512)
            self.conv5_2_D = nn.Conv2d(512, 512, 3, padding=1)
            self.conv5_2_D_bn = nn.BatchNorm2d(512)
            self.conv5_1_D = nn.Conv2d(512, 512, 3, padding=1)
            self.conv5_1_D_bn = nn.BatchNorm2d(512)
            
            self.conv4_3_D = nn.Conv2d(512, 512, 3, padding=1)
            self.conv4_3_D_bn = nn.BatchNorm2d(512)
            self.conv4_2_D = nn.Conv2d(512, 512, 3, padding=1)
            self.conv4_2_D_bn = nn.BatchNorm2d(512)
            self.conv4_1_D = nn.Conv2d(512, 256, 3, padding=1)
            self.conv4_1_D_bn = nn.BatchNorm2d(256)
            
            self.conv3_3_D = nn.Conv2d(256, 256, 3, padding=1)
            self.conv3_3_D_bn = nn.BatchNorm2d(256)
            self.conv3_2_D = nn.Conv2d(256, 256, 3, padding=1)
            self.conv3_2_D_bn = nn.BatchNorm2d(256)
            self.conv3_1_D = nn.Conv2d(256, 128, 3, padding=1)
            self.conv3_1_D_bn = nn.BatchNorm2d(128)
            
            self.conv2_2_D = nn.Conv2d(128, 128, 3, padding=1)
            self.conv2_2_D_bn = nn.BatchNorm2d(128)
            self.conv2_1_D = nn.Conv2d(128, 64, 3, padding=1)
            self.conv2_1_D_bn = nn.BatchNorm2d(64)
            
            self.conv1_2_D = nn.Conv2d(64, 64, 3, padding=1)
            self.conv1_2_D_bn = nn.BatchNorm2d(64)
            self.conv1_1_D = nn.Conv2d(64, out_channels, 3, padding=1)
            
            self.apply(self.weight_init)
            
        def forward(self, x):
            # Encoder block 1
            x = F.avg_pool2d(x, 4)  # shrink 1024 -> 256: the whole net runs at 1/4 scale
            x = self.conv1_1_bn(F.relu(self.conv1_1(x)))
            x1 = self.conv1_2_bn(F.relu(self.conv1_2(x)))
            size1 = x.size()
            x, mask1 = self.pool(x1)
            
            # Encoder block 2
            x = self.conv2_1_bn(F.relu(self.conv2_1(x)))
            #x = self.drop2_1(x)
            x2 = self.conv2_2_bn(F.relu(self.conv2_2(x)))
            size2 = x.size()
            x, mask2 = self.pool(x2)
            
            # Encoder block 3
            x = self.conv3_1_bn(F.relu(self.conv3_1(x)))
            x = self.conv3_2_bn(F.relu(self.conv3_2(x)))
            x3 = self.conv3_3_bn(F.relu(self.conv3_3(x)))
            size3 = x.size()
            x, mask3 = self.pool(x3)
            
            # Encoder block 4
            x = self.conv4_1_bn(F.relu(self.conv4_1(x)))
            x = self.conv4_2_bn(F.relu(self.conv4_2(x)))
            x4 = self.conv4_3_bn(F.relu(self.conv4_3(x)))
            size4 = x.size()
            x, mask4 = self.pool(x4)
            
            # Encoder block 5
            x = self.conv5_1_bn(F.relu(self.conv5_1(x)))
            x = self.conv5_2_bn(F.relu(self.conv5_2(x)))
            x = self.conv5_3_bn(F.relu(self.conv5_3(x)))
            size5 = x.size()
            x, mask5 = self.pool(x)
            
            # Decoder block 5
            x = self.unpool(x, mask5, output_size = size5)
            x = self.conv5_3_D_bn(F.relu(self.conv5_3_D(x)))
            x = self.conv5_2_D_bn(F.relu(self.conv5_2_D(x)))
            x = self.conv5_1_D_bn(F.relu(self.conv5_1_D(x)))
            
            # Decoder block 4
            x = self.unpool(x, mask4, output_size = size4)
            x = self.conv4_3_D_bn(F.relu(self.conv4_3_D(x)))
            x = self.conv4_2_D_bn(F.relu(self.conv4_2_D(x)))
            x = self.conv4_1_D_bn(F.relu(self.conv4_1_D(x)))
            
            # Decoder block 3
            x = self.unpool(x, mask3, output_size = size3)
            x = self.conv3_3_D_bn(F.relu(self.conv3_3_D(x)))
            x = self.conv3_2_D_bn(F.relu(self.conv3_2_D(x)))
            x = self.conv3_1_D_bn(F.relu(self.conv3_1_D(x)))
            
            # Decoder block 2
            x = self.unpool(x, mask2, output_size = size2)
            x = self.conv2_2_D_bn(F.relu(self.conv2_2_D(x)))
            x = self.conv2_1_D_bn(F.relu(self.conv2_1_D(x)))
            
            # Decoder block 1
            x = self.unpool(x, mask1, output_size = size1)
            x = self.conv1_2_D_bn(F.relu(self.conv1_2_D(x)))
            x = self.conv1_1_D(x)
            # upsample the logits 256 -> 1024 back to the input resolution
            return F.interpolate(x, mode='bilinear', scale_factor=4, align_corners=False)
    
        def load_pretrained_weights(self):
    
            #vgg16_weights = model_zoo.load_url("https://download.pytorch.org/models/vgg16_bn-6c64b313.pth")
            vgg16_weights = models.vgg16_bn(pretrained=True).state_dict()
            count_vgg = 0
            count_this = 0
    
            vggkeys = list(vgg16_weights.keys())
            thiskeys  = list(self.state_dict().keys())
    
            corresp_map = []
    
            while(True):
                vggkey = vggkeys[count_vgg]
                thiskey = thiskeys[count_this]
    
                if "classifier" in vggkey:
                    break
                
                while vggkey.split(".")[-1] not in thiskey:
                    count_this += 1
                    thiskey = thiskeys[count_this]
    
    
                corresp_map.append([vggkey, thiskey])
                count_vgg+=1
                count_this += 1
    
            mapped_weights = self.state_dict()
            for k_vgg, k_segnet in corresp_map:
                if (self.in_channels != 3) and "features" in k_vgg and "conv1_1." not in k_segnet:
                    mapped_weights[k_segnet] = vgg16_weights[k_vgg]
                elif (self.in_channels == 3) and "features" in k_vgg:
                    mapped_weights[k_segnet] = vgg16_weights[k_vgg]
    
            try:
                self.load_state_dict(mapped_weights)
                print("Loaded VGG-16 weights in Segnet !")
            except:
                print("Error VGG-16 weights in Segnet !")
                raise
        
        def load_from_filename(self, model_path):
            """Load weights from filename."""
            th = torch.load(model_path)  # load the weigths
            self.load_state_dict(th)
    
    
    def segnet_bn_relu(in_channels, out_channels, pretrained=False, **kwargs):
        """Constructs a ResNet-34 model.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
        """
        model = SegNet_BN_ReLU(in_channels, out_channels)
        if pretrained:
            model.load_pretrained_weights()
        return model
    
    if __name__=='__main__':
        net=segnet_bn_relu(3,4,False)
        print(net)
        x=torch.rand((1,3,1024,1024))
        print(net(x).shape)  # expect torch.Size([1, 4, 1024, 1024])
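
    I trained from scratch, but the pretrained path defined above also works. A usage sketch (it fetches the ImageNet VGG16-BN checkpoint through torchvision):

    # hypothetical usage of the pretrained branch defined above
    net = segnet_bn_relu(in_channels=3, out_channels=4, pretrained=True)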
    

  The training code:

    import argparse
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from farmdataset import FarmDataset
    
    from segnet import segnet_bn_relu as Unet
    
    import time
    
    from PIL import Image
    
    
    def train(args, model, device, train_loader, optimizer, epoch):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            #print(target.shape)
            optimizer.zero_grad()
            output = model(data)
            #print('output size',output.size(),output)
    
            output = F.log_softmax(output, dim=1)
            # nn.NLLLoss accepts (N, C, H, W) log-probabilities; the per-class
            # weights down-weight the dominant class to counter sample imbalance
            loss = nn.NLLLoss(weight=torch.tensor([0.1, 0.5, 0.5, 0.2]).to(device))(output, target)
            loss.backward()
           
            optimizer.step()
    
            #time.sleep(0.6)#make gpu sleep
            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
        if epoch % 2 == 0:
            # save a visual sanity check: predicted mask vs. ground truth
            imgd = output.detach()[0, :, :, :].cpu()
            img = torch.argmax(imgd, 0).byte().numpy()
            imgx = Image.fromarray(img).convert('L')
            # scale labels by 255 (with uint8 wraparound) just to make the classes visible
            imgxx = Image.fromarray(target.detach()[0, :, :].cpu().byte().numpy() * 255).convert('L')
            imgx.save("./tmp/predict{}.bmp".format(epoch))
            imgxx.save('./tmp/real{}.bmp'.format(epoch))
    
    def test(args, model, device, testdataset,issave=False):
        model.eval()
        test_loss = 0
        correct = 0
        evalid = [i + 7 for i in range(0, 2100, 15)]  # every 15th training image (140 samples)
        maxbatch = len(evalid)
        with torch.no_grad():
            for idx in evalid:
                data, target=testdataset[idx]
                data, target = data.unsqueeze(0).to(device), target.unsqueeze(0).to(device)
                #print(target.shape)
                target = target[:, :1472, :1472]        # crop to a fixed 1472x1472 window
                output = model(data[:, :, :1472, :1472])
                output = F.log_softmax(output, dim=1)
                loss = nn.NLLLoss().to(device)(output, target)
                test_loss += loss.item()
                
                r = torch.argmax(output[0], 0).byte()   # predicted class map

                tg = target.byte().squeeze(0)
                # mean IoU over the three foreground classes present in this image
                tmp = 0
                count = 0
                for i in range(1, 4):
                    mp = r == i                         # predicted mask for class i
                    tr = tg == i                        # ground-truth mask for class i
                    tp = mp * tr == 1                   # intersection
                    t = (mp + tr - tp).sum().item()     # union = |pred| + |gt| - |intersection|
                    if t == 0:
                        continue
                    else:
                        tmp += tp.sum().item() / t
                        count += 1
                if count > 0:
                    correct += tmp / count
               
                
                if issave:
                    Image.fromarray(r.cpu().numpy()).save('predict.png')
                    Image.fromarray(tg.cpu().numpy()).save('target.png')
                    input()
                    
        print('Test loss is {:.6f}, mean IoU is: {:.4f}'.format(test_loss / maxbatch, correct / maxbatch))
    
    
    def main():
        # Training settings
        parser = argparse.ArgumentParser(description='Scratch segmentation Example')
        parser.add_argument('--batch-size', type=int, default=8, metavar='N',
                            help='input batch size for training (default: 8)')
        parser.add_argument('--test-batch-size', type=int, default=8, metavar='N',
                            help='input batch size for testing (default: 8)')
        parser.add_argument('--epochs', type=int, default=30, metavar='N',
                            help='number of epochs to train (default: 30)')
        parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                            help='learning rate (default: 0.001)')
        parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                            help='SGD momentum (default: 0.5)')
        parser.add_argument('--no-cuda', action='store_true', default=False,
                            help='disables CUDA training')
        parser.add_argument('--seed', type=int, default=1, metavar='S',
                            help='random seed (default: 1)')
        parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                            help='how many batches to wait before logging training status')
        args = parser.parse_args()
        use_cuda = not args.no_cuda and torch.cuda.is_available()
    
        torch.manual_seed(args.seed)
    
        device = torch.device("cuda" if use_cuda else "cpu")
        print('my device is:', device)
    
        kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}
        train_loader = torch.utils.data.DataLoader(FarmDataset(istrain=True), batch_size=args.batch_size,
                                                   shuffle=True, drop_last=True, **kwargs)
        
        startepoch = 0
        # resume from a saved checkpoint when startepoch is nonzero
        model = torch.load('./tmp/model{}'.format(startepoch)) if startepoch else Unet(3, 4).to(device)
        args.epochs = 50  # override the command-line default
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    
        for epoch in range(startepoch, args.epochs + 1):
            train(args, model, device, train_loader, optimizer, epoch)
            if epoch % 3 == 0:
                print(epoch)
                test(args, model, device, FarmDataset(istrain=True, isaug=False), issave=False)
                torch.save(model, './tmp/model{}'.format(epoch))
        
    if __name__ == '__main__':
        main()
    

  Every three epochs, the training code evaluates the training accuracy. The evaluation data is still drawn from the training set, just subsampled. This score can be used to choose which epoch's checkpoint to use for prediction, as sketched below.
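
    A minimal sketch of reloading a saved checkpoint for prediction (the epoch number 27 here is hypothetical; pick whichever checkpoint scored best):

    import torch

    model = torch.load('./tmp/model27')  # hypothetical best-scoring checkpoint
    model.eval()                         # put BatchNorm layers into inference mode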

    Training accuracy can reach about 0.4, but by that point the model seems to have overfit: the submitted results were not as good.

    My impression so far is that every small improvement raises the score a little. To keep improving the results, the following directions look worthwhile:

    Sample class imbalance

    The loss function (see the sketch right after this list)

    Model architecture design, drawing on PSPNet, UNet, DeepLab, or GAN-style pix2pix.
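
    As a concrete example for the loss-function direction, here is a minimal soft Dice loss sketch (my own illustration, not part of the competition code); because it directly optimizes region overlap, it is also less sensitive to class imbalance than plain cross-entropy:

    import torch
    import torch.nn.functional as F

    def dice_loss(logits, target, num_classes=4, eps=1e-6):
        """Soft multi-class Dice loss; logits: (N, C, H, W), target: (N, H, W) int64."""
        probs = F.softmax(logits, dim=1)
        onehot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
        dims = (0, 2, 3)                          # sum over batch and spatial dims
        inter = (probs * onehot).sum(dims)
        union = probs.sum(dims) + onehot.sum(dims)
        dice = (2 * inter + eps) / (union + eps)  # per-class soft Dice score
        return 1 - dice.mean()                    # average over the 4 classes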

    In short, it feels like any bit of improvement here is effort well spent.

    That completes the full code walkthrough: data tiling, dataset preparation, data augmentation, saving predictions, the deep segmentation network, and network training. With all of this in place, your score should reach 0.2 or above. It took me several days of tinkering to get this far; I hope it can serve as a baseline, and I look forward to seeing more inventive model ideas.

    (The end)

     
