zoukankan      html  css  js  c++  java
  • U-Net网络的Pytorch实现

    1.文章原文地址

    U-Net: Convolutional Networks for Biomedical Image Segmentation

    2.文章摘要

    普遍认为成功训练深度神经网络需要大量标注的训练数据。在本文中,我们提出了一个网络结构,以及使用数据增强的策略来训练网络使得可用的标注样本更加有效的被使用。这个网络是由一个捕捉上下文信息的收缩部分和与之相对称的放大部分,后者能够准确的定位。我们的结果展示了这个网络可以进行端到端的训练,使用非常少的数据就可以达到非常好的结果,并且超过了当前的最佳方法(滑动窗网络)在ISBII挑战赛上电子显微镜下神经结构的分割的结果。利用透射光显微镜图像使用相同网络进行训练,我们大幅度的赢得了2015年的ISBI细胞追踪挑战赛。而且,这个网络非常快,在一个当前的GPU上,分割一个512x512的图像所花费的时间少于一秒。完整的代码以及训练好的网络可见(基于Caffe)http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.

    3.网络结构

    4.Pytorch实现

      1 import torch
      2 import torch.nn as nn
      3 import torch.nn.functional as F
      4 from torchsummary import summary
      5 
      6 
      7 class unetConv2(nn.Module):
      8     def __init__(self,in_size,out_size,is_batchnorm):
      9         super(unetConv2,self).__init__()
     10 
     11         if is_batchnorm:
     12             self.conv1=nn.Sequential(
     13                 nn.Conv2d(in_size,out_size,kernel_size=3,stride=1,padding=0),
     14                 nn.BatchNorm2d(out_size),
     15                 nn.ReLU(inplace=True),
     16             )
     17             self.conv2=nn.Sequential(
     18                 nn.Conv2d(out_size,out_size,kernel_size=3,stride=1,padding=0),
     19                 nn.BatchNorm2d(out_size),
     20                 nn.ReLU(inplace=True),
     21             )
     22         else:
     23             self.conv1=nn.Sequential(
     24                 nn.Conv2d(in_size,out_size,kernel_size=3,stride=1,padding=0),
     25                 nn.ReLU(inplace=True),
     26             )
     27             self.conv2=nn.Sequential(
     28                 nn.Conv2d(out_size,out_size,kernel_size=3,stride=1,padding=0),
     29                 nn.ReLU(inplace=True)
     30             )
     31     def forward(self, inputs):
     32         outputs=self.conv1(inputs)
     33         outputs=self.conv2(outputs)
     34 
     35         return outputs
     36 
     37 class unetUp(nn.Module):
     38     def __init__(self,in_size,out_size,is_deconv):
     39         super(unetUp,self).__init__()
     40         self.conv=unetConv2(in_size,out_size,False)
     41         if is_deconv:
     42             self.up=nn.ConvTranspose2d(in_size,out_size,kernel_size=2,stride=2)
     43         else:
     44             self.up=nn.UpsamplingBilinear2d(scale_factor=2)
     45 
     46     def forward(self, inputs1,inputs2):
     47         outputs2=self.up(inputs2)
     48         offset=outputs2.size()[2]-inputs1.size()[2]
     49         padding=2*[offset//2,offset//2]
     50         outputs1=F.pad(inputs1,padding)     #padding is negative, size become smaller
     51 
     52         return self.conv(torch.cat([outputs1,outputs2],1))
     53 
     54 class unet(nn.Module):
     55     def __init__(self,feature_scale=4,n_classes=21,is_deconv=True,in_channels=3,is_batchnorm=True):
     56         super(unet,self).__init__()
     57         self.is_deconv=is_deconv
     58         self.in_channels=in_channels
     59         self.is_batchnorm=is_batchnorm
     60         self.feature_scale=feature_scale
     61 
     62         filters=[64,128,256,512,1024]
     63         filters=[int(x/self.feature_scale) for x in filters]
     64 
     65         #downsample
     66         self.conv1=unetConv2(self.in_channels,filters[0],self.is_batchnorm)
     67         self.maxpool1=nn.MaxPool2d(kernel_size=2)
     68 
     69         self.conv2=unetConv2(filters[0],filters[1],self.is_batchnorm)
     70         self.maxpool2=nn.MaxPool2d(kernel_size=2)
     71 
     72         self.conv3=unetConv2(filters[1],filters[2],self.is_batchnorm)
     73         self.maxpool3=nn.MaxPool2d(kernel_size=2)
     74 
     75         self.conv4=unetConv2(filters[2],filters[3],self.is_batchnorm)
     76         self.maxpool4=nn.MaxPool2d(kernel_size=2)
     77 
     78         self.center=unetConv2(filters[3],filters[4],self.is_batchnorm)
     79 
     80         #umsampling
     81         self.up_concat4=unetUp(filters[4],filters[3],self.is_deconv)
     82         self.up_concat3=unetUp(filters[3],filters[2],self.is_deconv)
     83         self.up_concat2=unetUp(filters[2],filters[1],self.is_deconv)
     84         self.up_concat1=unetUp(filters[1],filters[0],self.is_deconv)
     85 
     86         #final conv (without and concat)
     87         self.final=nn.Conv2d(filters[0],n_classes,kernel_size=1)
     88 
     89     def forward(self, inputs):
     90         conv1=self.conv1(inputs)
     91         maxpool1=self.maxpool1(conv1)
     92 
     93         conv2=self.conv2(maxpool1)
     94         maxpool2=self.maxpool2(conv2)
     95 
     96         conv3=self.conv3(maxpool2)
     97         maxpool3=self.maxpool3(conv3)
     98 
     99         conv4=self.conv4(maxpool3)
    100         maxpool4=self.maxpool4(conv4)
    101 
    102         center=self.center(maxpool4)
    103         up4=self.up_concat4(conv4,center)
    104         up3=self.up_concat3(conv3,up4)
    105         up2=self.up_concat2(conv2,up3)
    106         up1=self.up_concat1(conv1,up2)
    107 
    108         final=self.final(up1)
    109 
    110         return final
    111 
    112 if __name__=="__main__":
    113     model=unet(feature_scale=1)
    114     print(summary(model,(3,572,572)))
     1 ----------------------------------------------------------------
     2         Layer (type)               Output Shape         Param #
     3 ================================================================
     4             Conv2d-1         [-1, 64, 570, 570]           1,792
     5        BatchNorm2d-2         [-1, 64, 570, 570]             128
     6               ReLU-3         [-1, 64, 570, 570]               0
     7             Conv2d-4         [-1, 64, 568, 568]          36,928
     8        BatchNorm2d-5         [-1, 64, 568, 568]             128
     9               ReLU-6         [-1, 64, 568, 568]               0
    10          unetConv2-7         [-1, 64, 568, 568]               0
    11          MaxPool2d-8         [-1, 64, 284, 284]               0
    12             Conv2d-9        [-1, 128, 282, 282]          73,856
    13       BatchNorm2d-10        [-1, 128, 282, 282]             256
    14              ReLU-11        [-1, 128, 282, 282]               0
    15            Conv2d-12        [-1, 128, 280, 280]         147,584
    16       BatchNorm2d-13        [-1, 128, 280, 280]             256
    17              ReLU-14        [-1, 128, 280, 280]               0
    18         unetConv2-15        [-1, 128, 280, 280]               0
    19         MaxPool2d-16        [-1, 128, 140, 140]               0
    20            Conv2d-17        [-1, 256, 138, 138]         295,168
    21       BatchNorm2d-18        [-1, 256, 138, 138]             512
    22              ReLU-19        [-1, 256, 138, 138]               0
    23            Conv2d-20        [-1, 256, 136, 136]         590,080
    24       BatchNorm2d-21        [-1, 256, 136, 136]             512
    25              ReLU-22        [-1, 256, 136, 136]               0
    26         unetConv2-23        [-1, 256, 136, 136]               0
    27         MaxPool2d-24          [-1, 256, 68, 68]               0
    28            Conv2d-25          [-1, 512, 66, 66]       1,180,160
    29       BatchNorm2d-26          [-1, 512, 66, 66]           1,024
    30              ReLU-27          [-1, 512, 66, 66]               0
    31            Conv2d-28          [-1, 512, 64, 64]       2,359,808
    32       BatchNorm2d-29          [-1, 512, 64, 64]           1,024
    33              ReLU-30          [-1, 512, 64, 64]               0
    34         unetConv2-31          [-1, 512, 64, 64]               0
    35         MaxPool2d-32          [-1, 512, 32, 32]               0
    36            Conv2d-33         [-1, 1024, 30, 30]       4,719,616
    37       BatchNorm2d-34         [-1, 1024, 30, 30]           2,048
    38              ReLU-35         [-1, 1024, 30, 30]               0
    39            Conv2d-36         [-1, 1024, 28, 28]       9,438,208
    40       BatchNorm2d-37         [-1, 1024, 28, 28]           2,048
    41              ReLU-38         [-1, 1024, 28, 28]               0
    42         unetConv2-39         [-1, 1024, 28, 28]               0
    43   ConvTranspose2d-40          [-1, 512, 56, 56]       2,097,664
    44            Conv2d-41          [-1, 512, 54, 54]       4,719,104
    45              ReLU-42          [-1, 512, 54, 54]               0
    46            Conv2d-43          [-1, 512, 52, 52]       2,359,808
    47              ReLU-44          [-1, 512, 52, 52]               0
    48         unetConv2-45          [-1, 512, 52, 52]               0
    49            unetUp-46          [-1, 512, 52, 52]               0
    50   ConvTranspose2d-47        [-1, 256, 104, 104]         524,544
    51            Conv2d-48        [-1, 256, 102, 102]       1,179,904
    52              ReLU-49        [-1, 256, 102, 102]               0
    53            Conv2d-50        [-1, 256, 100, 100]         590,080
    54              ReLU-51        [-1, 256, 100, 100]               0
    55         unetConv2-52        [-1, 256, 100, 100]               0
    56            unetUp-53        [-1, 256, 100, 100]               0
    57   ConvTranspose2d-54        [-1, 128, 200, 200]         131,200
    58            Conv2d-55        [-1, 128, 198, 198]         295,040
    59              ReLU-56        [-1, 128, 198, 198]               0
    60            Conv2d-57        [-1, 128, 196, 196]         147,584
    61              ReLU-58        [-1, 128, 196, 196]               0
    62         unetConv2-59        [-1, 128, 196, 196]               0
    63            unetUp-60        [-1, 128, 196, 196]               0
    64   ConvTranspose2d-61         [-1, 64, 392, 392]          32,832
    65            Conv2d-62         [-1, 64, 390, 390]          73,792
    66              ReLU-63         [-1, 64, 390, 390]               0
    67            Conv2d-64         [-1, 64, 388, 388]          36,928
    68              ReLU-65         [-1, 64, 388, 388]               0
    69         unetConv2-66         [-1, 64, 388, 388]               0
    70            unetUp-67         [-1, 64, 388, 388]               0
    71            Conv2d-68         [-1, 21, 388, 388]           1,365
    72 ================================================================
    73 Total params: 31,040,981
    74 Trainable params: 31,040,981
    75 Non-trainable params: 0
    76 ----------------------------------------------------------------
    77 Input size (MB): 3.74
    78 Forward/backward pass size (MB): 3158.15
    79 Params size (MB): 118.41
    80 Estimated Total Size (MB): 3280.31

    参考

    https://github.com/meetshah1995/pytorch-semseg

  • 相关阅读:
    phpajax高级篇
    一天学会ajax (php环境)
    php生成静态文件的方法
    MongoDB查询文档
    MongoDB删除文档
    MongoDB索引管理
    MongoDB插入文档
    MongoDB排序记录
    MongoDB 更新文档
    mongoDB 固定集合(capped collection)
  • 原文地址:https://www.cnblogs.com/ys99/p/10889695.html
Copyright © 2011-2022 走看看