对于分割网络,如果当成一个黑箱就是:输入一个3x1024x1024 输出4x1024x1024。
我没有使用二分类,直接使用了四分类。
分割网络使用了SegNet,没有加载预训练模型,参数也是默认初始化。为了加快训练,1024的输入进入网络后先通过
average pooling缩小到256的尺寸,在输出层再用bilinear插值放大4倍,相当于直接在256的尺寸上训练。
import os
import urllib

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from torchvision import models
except ImportError:  # torchvision is only needed by load_pretrained_weights()
    models = None

# Adapted from
# https://raw.githubusercontent.com/delta-onera/delta_tb/master/deltatb/networks/net_segnet_bn_relu.py


class SegNet_BN_ReLU(nn.Module):
    """SegNet-style encoder/decoder with a VGG-16 shaped encoder.

    Every stage is ``conv3x3 -> ReLU -> BatchNorm``.  The encoder keeps the
    max-pooling indices of each block and the decoder feeds them to
    ``MaxUnpool2d``.  To make training cheaper the input is first
    average-pooled by ``INPUT_DOWNSCALE`` and the logits are bilinearly
    resized back to the input resolution at the very end.

    Layer attribute names (``conv1_1``, ``conv1_1_bn``, ..., ``conv1_1_D``)
    and their registration order are part of the checkpoint format and of the
    VGG weight mapping in :meth:`load_pretrained_weights`; do not change them.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of classes (channels of the returned logits).
    """

    # Output width of each 3x3 convolution, grouped per encoder block.
    ENCODER_WIDTHS = (
        (64, 64),
        (128, 128),
        (256, 256, 256),
        (512, 512, 512),
        (512, 512, 512),
    )
    # Factor by which the input is average-pooled before the first conv.
    INPUT_DOWNSCALE = 4

    @staticmethod
    def weight_init(m):
        """Kaiming-initialise ``nn.Linear`` modules.

        NOTE: this network contains no ``nn.Linear`` layer, so applying this
        function leaves all convolutions with PyTorch's default
        initialisation.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.kaiming_normal_(m.weight.data)

    def __init__(self, in_channels, out_channels):
        super(SegNet_BN_ReLU, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.pool = nn.MaxPool2d(2, return_indices=True)
        self.unpool = nn.MaxUnpool2d(2)

        # Encoder: conv{block}_{idx} / conv{block}_{idx}_bn
        channels = in_channels
        for block, widths in enumerate(self.ENCODER_WIDTHS, start=1):
            for idx, width in enumerate(widths, start=1):
                self._add_conv_bn("conv{}_{}".format(block, idx), channels, width)
                channels = width

        # Decoder mirrors the encoder: conv{block}_{idx}_D, deepest block first.
        # The last conv of a block (idx == 1) narrows to the width of the
        # block above it.
        for block in range(len(self.ENCODER_WIDTHS), 0, -1):
            widths = self.ENCODER_WIDTHS[block - 1]
            width = widths[0]
            for idx in range(len(widths), 1, -1):
                self._add_conv_bn("conv{}_{}_D".format(block, idx), width, width)
            if block > 1:
                narrower = self.ENCODER_WIDTHS[block - 2][0]
                self._add_conv_bn("conv{}_1_D".format(block), width, narrower)

        # Classifier layer: produces raw logits, hence no ReLU / BatchNorm.
        self.conv1_1_D = nn.Conv2d(
            self.ENCODER_WIDTHS[0][0], out_channels, 3, padding=1
        )

        self.apply(self.weight_init)

    def _add_conv_bn(self, name, in_ch, out_ch):
        """Register ``<name>`` (3x3 conv) followed by ``<name>_bn``."""
        setattr(self, name, nn.Conv2d(in_ch, out_ch, 3, padding=1))
        setattr(self, name + "_bn", nn.BatchNorm2d(out_ch))

    def _conv_relu_bn(self, x, name):
        """Apply ``<name>`` -> ReLU -> ``<name>_bn`` to ``x``."""
        return getattr(self, name + "_bn")(F.relu(getattr(self, name)(x)))

    def forward(self, x):
        """Return per-pixel class logits with the same H x W as ``x``.

        The five 2x poolings after the initial 4x average pooling mean the
        input needs a spatial size of at least 128 x 128.
        """
        out_size = tuple(x.shape[-2:])
        x = F.avg_pool2d(x, self.INPUT_DOWNSCALE)

        # Encoder: remember pooling indices and pre-pool sizes per block.
        pooled = []
        for block, widths in enumerate(self.ENCODER_WIDTHS, start=1):
            for idx in range(1, len(widths) + 1):
                x = self._conv_relu_bn(x, "conv{}_{}".format(block, idx))
            size = x.size()
            x, indices = self.pool(x)
            pooled.append((indices, size))

        # Decoder: unpool with the matching indices, deepest block first.
        for block in range(len(self.ENCODER_WIDTHS), 0, -1):
            indices, size = pooled.pop()
            x = self.unpool(x, indices, output_size=size)
            for idx in range(len(self.ENCODER_WIDTHS[block - 1]), 1, -1):
                x = self._conv_relu_bn(x, "conv{}_{}_D".format(block, idx))
            if block > 1:
                x = self._conv_relu_bn(x, "conv{}_1_D".format(block))

        x = self.conv1_1_D(x)

        # Resize to the original input size instead of a fixed x4 so inputs
        # whose sides are not multiples of 4 still get a same-sized output.
        return F.interpolate(x, size=out_size, mode='bilinear',
                             align_corners=False)

    def load_pretrained_weights(self):
        """Copy ImageNet VGG-16-BN feature weights into the encoder.

        VGG ``features.*`` tensors are matched, in order, to the next tensor
        of this model whose name contains the same final component
        (``weight``, ``bias``, ``running_mean`` ...).  When ``in_channels`` is
        not 3 the first convolution (``conv1_1``) keeps its own weights.

        Raises:
            ImportError: if torchvision is not installed.
            Exception: whatever ``load_state_dict`` raises on a mismatch.
        """
        if models is None:
            raise ImportError(
                "torchvision is required to load the pretrained VGG-16 weights"
            )
        vgg16_weights = models.vgg16_bn(pretrained=True).state_dict()

        thiskeys = list(self.state_dict().keys())
        corresp_map = []
        count_this = 0
        for vggkey in vgg16_weights:
            if "classifier" in vggkey:
                break
            suffix = vggkey.split(".")[-1]
            while suffix not in thiskeys[count_this]:
                count_this += 1
            corresp_map.append((vggkey, thiskeys[count_this]))
            count_this += 1

        mapped_weights = self.state_dict()
        for k_vgg, k_segnet in corresp_map:
            if "features" not in k_vgg:
                continue
            if self.in_channels != 3 and "conv1_1." in k_segnet:
                # First conv has a different input width; keep its own init.
                continue
            mapped_weights[k_segnet] = vgg16_weights[k_vgg]

        try:
            self.load_state_dict(mapped_weights)
            print("Loaded VGG-16 weights in Segnet !")
        except Exception:
            print("Error VGG-16 weights in Segnet !")
            raise

    def load_from_filename(self, model_path):
        """Load a state dict saved with ``torch.save`` from ``model_path``."""
        # map_location lets checkpoints written on a GPU load on CPU-only
        # hosts; load_state_dict copies into the existing parameters, so the
        # model stays on whatever device it already lives on.
        state = torch.load(model_path, map_location="cpu")
        self.load_state_dict(state)


def segnet_bn_relu(in_channels, out_channels, pretrained=False, **kwargs):
    """Construct a SegNet_BN_ReLU model.

    Args:
        in_channels: number of input image channels.
        out_channels: number of output classes.
        pretrained (bool): if True, initialise the encoder from the
            ImageNet-pretrained VGG-16-BN weights.
        **kwargs: accepted for call-site compatibility and ignored.
    """
    model = SegNet_BN_ReLU(in_channels, out_channels)
    if pretrained:
        model.load_pretrained_weights()
    return model


if __name__ == '__main__':
    net = segnet_bn_relu(3, 4, False)
    print(net)
    x = torch.rand((1, 3, 1024, 1024))
    with torch.no_grad():  # shape check only; no need to keep the graph
        print(net(x).shape)
训练网络的代码:
import argparse
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image

from farmdataset import FarmDataset
from segnet import segnet_bn_relu as Unet

# Per-class weights of the training loss (index = class id).
CLASS_WEIGHTS = (0.1, 0.5, 0.5, 0.2)
# Directory receiving checkpoints and per-epoch prediction snapshots.
TMP_DIR = './tmp'


def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch and, on even epochs, dump a sample prediction.

    Args:
        args: parsed CLI arguments; only ``log_interval`` is read here.
        model: segmentation network returning per-class logits.
        device: torch device the batches are moved to.
        train_loader: iterable yielding ``(data, target)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: current epoch number (used for logging and file names).
    """
    model.train()
    # Built once per call and on the training device, so CPU runs work too.
    class_weights = torch.tensor(CLASS_WEIGHTS, device=device)
    output = target = None

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = F.log_softmax(model(data), dim=1)
        loss = F.nll_loss(output, target, weight=class_weights)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    # NOTE(review): the original formatting was lost; the snapshot is assumed
    # to be written once per epoch from the last batch - confirm.
    if epoch % 2 == 0 and output is not None:
        os.makedirs(TMP_DIR, exist_ok=True)
        imgd = output.detach()[0, :, :, :].cpu()
        img = torch.argmax(imgd, 0).byte().numpy()
        # Prediction is stored as raw class ids, target as ids * 255 (uint8).
        imgx = Image.fromarray(img).convert('L')
        imgxx = Image.fromarray(
            target.detach()[0, :, :].cpu().byte().numpy() * 255).convert('L')
        imgx.save("./tmp/predict{}.bmp".format(epoch))
        imgxx.save('./tmp/real{}.bmp'.format(epoch))


def test(args, model, device, testdataset, issave=False):
    """Evaluate ``model`` on a fixed sample of ``testdataset``.

    Every 15th item starting at index 7 (below 2100) is cropped to its
    top-left 1472 x 1472 window.  For each item the IoU of classes 1-3 is
    averaged over the classes that occur in the prediction or the target;
    items where none of them occurs are left out of the reported mean.

    Args:
        args: parsed CLI arguments (unused, kept for a uniform signature).
        model: segmentation network returning per-class logits.
        device: torch device to evaluate on.
        testdataset: indexable dataset returning ``(data, target)``.
        issave: if True, write ``predict.png`` / ``target.png`` for each item
            and wait for Enter before continuing.
    """
    model.eval()
    test_loss = 0.0
    correct = 0.0
    scored = 0
    evalid = [i + 7 for i in range(0, 2100, 15)]
    maxbatch = len(evalid)

    with torch.no_grad():
        for idx in evalid:
            data, target = testdataset[idx]
            data = data.unsqueeze(0).to(device)
            target = target.unsqueeze(0).to(device)
            target = target[:, :1472, :1472]
            output = F.log_softmax(model(data[:, :, :1472, :1472]), dim=1)
            test_loss += F.nll_loss(output, target).item()

            r = torch.argmax(output[0], 0).byte()
            tg = target.byte().squeeze(0)

            ious = []
            for cls in range(1, 4):
                pred_mask = r == cls
                true_mask = tg == cls
                # Logical ops instead of +/-: arithmetic on bool masks is
                # rejected by current PyTorch.
                union = (pred_mask | true_mask).sum().item()
                if union == 0:
                    continue
                ious.append((pred_mask & true_mask).sum().item() / union)
            if ious:
                correct += sum(ious) / len(ious)
                scored += 1

            if issave:
                Image.fromarray(r.cpu().numpy()).save('predict.png')
                Image.fromarray(tg.cpu().numpy()).save('target.png')
                input()  # deliberate pause so the images can be inspected

    mean_iou_pct = 100. * correct / scored if scored else 0.0
    print('Test Loss is {:.6f}, mean precision is: {:.4f}%'.format(
        test_loss / maxbatch, mean_iou_pct))


def main():
    """Parse CLI options, then train and periodically evaluate/checkpoint."""
    # Training settings
    parser = argparse.ArgumentParser(description='Scratch segmentation Example')
    parser.add_argument('--batch-size', type=int, default=8, metavar='N',
                        help='input batch size for training (default: 8)')
    parser.add_argument('--test-batch-size', type=int, default=8, metavar='N',
                        help='input batch size for testing (default: 8)')
    # The original overwrote args.epochs with 50 after parsing; 50 is now the
    # default so the flag is actually honoured.
    parser.add_argument('--epochs', type=int, default=50, metavar='N',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    print('my device is :', device)

    os.makedirs(TMP_DIR, exist_ok=True)

    kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        FarmDataset(istrain=True), batch_size=args.batch_size,
        shuffle=True, drop_last=True, **kwargs)

    # Set to a previously saved epoch number to resume from ./tmp/model<N>.
    startepoch = 0
    if startepoch:
        # NOTE(review): this unpickles a whole model object; only load
        # checkpoints you created yourself.
        model = torch.load('./tmp/model{}'.format(startepoch),
                           map_location=device)
    else:
        model = Unet(3, 4)
    model = model.to(device)

    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum)

    for epoch in range(startepoch, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        if epoch % 3 == 0:
            print(epoch)
            # Evaluation re-uses the training data, un-augmented and sampled.
            test(args, model, device,
                 FarmDataset(istrain=True, isaug=False), issave=False)
            torch.save(model, './tmp/model{}'.format(epoch))


if __name__ == '__main__':
    main()
训练代码每隔三轮,评测一次训练精度,测试数据仍然使用训练数据,只是抽样了。可以根据该精度选择使用什么时刻的网络作为预测节点。
训练精度可以达到0.4,但是这时候貌似过拟合了,提交结果并不好。
到现在感觉,只要改进一点,分数就会高一点。如果要继续提高成绩,感觉可以从以下几个方面改进:
样本的不均衡
损失函数
模型结构的设计 可以参考PSP,UNET,deeplab,或者GAN的pix2pix。
总之,感觉只要进行一点改进,功夫就不会白费。
整个从数据切割,数据集准备,数据增强,预测结果保存,深度分割网络 和网络训练,全部代码到此分享完毕,
做完这些你的结果就能到0.2以上。 也是折腾了好几天才到现在,希望这能成为一个基线,看到更精彩的模型思路。
(完)