zoukankan      html  css  js  c++  java
  • 【655】Res-U-Net 详解说明

    [1] UNet with ResBlock for Semantic Segmentation

    [2] github - UNet-with-ResBlock/resnet34_unet_model.py【上面对应的代码】

    [3] github - ResUNet【运行时报 OOM，内存不足】

      结构图如下所示:

      每一个 block 里面都有一个残差连接的部分。

      代码实现【1】：二分类，最后一层用 sigmoid【基于参考资料 [2] 的代码】

    import numpy as np
    from keras.backend import int_shape
    from keras.models import Model
    from keras.layers import Conv2D, Conv3D, MaxPooling2D, MaxPooling3D, UpSampling2D, UpSampling3D, Add, BatchNormalization, Input, Activation, Lambda, Concatenate
    
    
    def res_unet(filter_root, depth, n_class=2, input_size=(256, 256, 1), activation='relu', batch_norm=True, final_activation='softmax'):
        """
        Build UNet model with ResBlock.

        The encoder stacks ``depth`` residual blocks; every block except the
        last is followed by max pooling (``depth - 1`` poolings in total). The
        decoder mirrors this with ``depth - 1`` upsampling residual blocks, each
        concatenating the matching encoder output (long connection).

        Args:
            filter_root (int): Number of filters in the first encoder block;
                block ``i`` uses ``2**i * filter_root`` filters.
            depth (int): Number of encoder residual blocks. Every spatial
                dimension of ``input_size`` must be divisible by
                ``2**(depth - 1)`` so upsampled maps line up with the stored
                long connections.
            n_class (int, optional): Currently unused -- the output layer always
                has a single channel. Kept for signature compatibility.
                Defaults to 2.
            input_size (tuple, optional): Input shape including the trailing
                channel axis. Length 3 builds a 2D model (Conv2D/MaxPooling2D/
                UpSampling2D), length 4 a 3D model. Defaults to (256, 256, 1).
            activation (str, optional): Activation applied after each conv/BN
                pair and after each residual addition. Defaults to 'relu'.
            batch_norm (bool, optional): Whether to insert BatchNormalization
                after the 3x3 convolutions. Defaults to True.
            final_activation (str, optional): Activation of the 1x1 output
                convolution. Defaults to 'softmax'. Note: because the output has
                a single channel, 'softmax' over the last axis is constantly
                1.0; pass 'sigmoid' for binary segmentation.

        Returns:
            keras.models.Model: the model, named 'Res-UNet'.

        Raises:
            ValueError: if ``input_size`` does not have length 3 or 4, or a known
                (non-None) spatial dimension is not divisible by
                ``2**(depth - 1)``.
        """
        # Validate arguments before creating any Keras object so bad shapes fail
        # fast with a clear message. Previously an unsupported rank left `Conv`
        # unbound (UnboundLocalError) and a bad spatial size only failed deep
        # inside a Concatenate layer.
        if len(input_size) not in (3, 4):
            raise ValueError(
                "input_size must have length 3 (2D) or 4 (3D) including the "
                "channel axis, got {!r}".format(input_size))
        # Pooling happens after every encoder block except the last one.
        size_multiple = 2 ** max(depth - 1, 0)
        for dim in input_size[:-1]:
            if dim is not None and dim % size_multiple != 0:
                raise ValueError(
                    "spatial dimensions of input_size must be divisible by "
                    "2**(depth - 1) = {}, got {!r}".format(size_multiple, input_size))

        if len(input_size) == 3:
            Conv = Conv2D
            MaxPooling = MaxPooling2D
            UpSampling = UpSampling2D
        else:
            Conv = Conv3D
            MaxPooling = MaxPooling3D
            UpSampling = UpSampling3D

        inputs = Input(input_size)
        x = inputs
        # Encoder outputs keyed by level, consumed by the decoder's long connections.
        long_connection_store = {}

        # Down sampling
        for i in range(depth):
            out_channel = 2**i * filter_root

            # Residual/Skip connection: 1x1 projection so channels match conv2.
            res = Conv(out_channel, kernel_size=1, padding='same', use_bias=False, name="Identity{}_1".format(i))(x)

            # First Conv Block with Conv, BN and activation
            conv1 = Conv(out_channel, kernel_size=3, padding='same', name="Conv{}_1".format(i))(x)
            if batch_norm:
                conv1 = BatchNormalization(name="BN{}_1".format(i))(conv1)
            act1 = Activation(activation, name="Act{}_1".format(i))(conv1)

            # Second Conv block with Conv and BN only (activation comes after the add)
            conv2 = Conv(out_channel, kernel_size=3, padding='same', name="Conv{}_2".format(i))(act1)
            if batch_norm:
                conv2 = BatchNormalization(name="BN{}_2".format(i))(conv2)

            resconnection = Add(name="Add{}_1".format(i))([res, conv2])

            act2 = Activation(activation, name="Act{}_2".format(i))(resconnection)

            # Max pooling (skipped for the deepest block, which feeds the decoder directly)
            if i < depth - 1:
                long_connection_store[str(i)] = act2
                x = MaxPooling(padding='same', name="MaxPooling{}_1".format(i))(act2)
            else:
                x = act2

        # Upsampling
        for i in range(depth - 2, -1, -1):
            out_channel = 2**(i) * filter_root

            # long connection from down sampling path.
            long_connection = long_connection_store[str(i)]

            up1 = UpSampling(name="UpSampling{}_1".format(i))(x)
            # NOTE(review): this conv hard-codes 'relu' instead of `activation`;
            # kept as-is to preserve the existing model.
            up_conv1 = Conv(out_channel, 2, activation='relu', padding='same', name="upConv{}_1".format(i))(up1)

            #  Concatenate.
            up_conc = Concatenate(axis=-1, name="upConcatenate{}_1".format(i))([up_conv1, long_connection])

            #  Convolutions
            up_conv2 = Conv(out_channel, 3, padding='same', name="upConv{}_1_".format(i))(up_conc)
            if batch_norm:
                up_conv2 = BatchNormalization(name="upBN{}_1".format(i))(up_conv2)
            up_act1 = Activation(activation, name="upAct{}_1".format(i))(up_conv2)

            up_conv2 = Conv(out_channel, 3, padding='same', name="upConv{}_2".format(i))(up_act1)
            if batch_norm:
                up_conv2 = BatchNormalization(name="upBN{}_2".format(i))(up_conv2)

            # Residual/Skip connection: 1x1 projection of the concatenated tensor.
            res = Conv(out_channel, kernel_size=1, padding='same', use_bias=False, name="upIdentity{}_1".format(i))(up_conc)

            resconnection = Add(name="upAdd{}_1".format(i))([res, up_conv2])

            x = Activation(activation, name="upAct{}_2".format(i))(resconnection)

        # Final convolution: always a single output channel (n_class is not used).
        output = Conv(1, 1, padding='same', activation=final_activation, name='output')(x)

        return Model(inputs, outputs=output, name='Res-UNet')
    
    # Example: binary-segmentation Res-UNet on 512x512 RGB input with a sigmoid output.
    model = res_unet(
        filter_root=64,
        depth=5,
        n_class=2,
        input_size=(512, 512, 3),
        activation='relu',
        batch_norm=True,
        final_activation='sigmoid',
    )
    model.summary()
    

      代码实现【2】：二分类，最后一层用 sigmoid

    from keras.applications import vgg16
    from keras.models import Model, Sequential
    from keras.layers import Conv2D, UpSampling2D, Input, add, concatenate, Dropout, Activation, BatchNormalization
    from keras.utils.vis_utils import plot_model
    
    def batch_Norm_Activation(x, BN=False):
        """Apply an optional BatchNormalization followed by a ReLU activation.

        Args:
            x: Keras tensor to transform.
            BN (bool, optional): When true, insert a ``BatchNormalization`` layer
                before the ReLU. Defaults to False, i.e. batch normalization is
                OFF unless the caller passes ``BN=True`` (or this default is
                changed to True).

        Returns:
            The Keras tensor produced by the final ``Activation("relu")`` layer.
        """
        if BN:
            x = BatchNormalization()(x)
        return Activation("relu")(x)
    
    
    def _res_down_block(x, n_filters):
        """Strided pre-activation residual block: downsamples H and W by 2, outputs n_filters channels."""
        res = batch_Norm_Activation(x)
        res = Conv2D(n_filters, kernel_size=(3, 3), padding='same', strides=(2, 2))(res)
        res = batch_Norm_Activation(res)
        res = Conv2D(n_filters, kernel_size=(3, 3), padding='same', strides=(1, 1))(res)
        # Projection shortcut: a strided 3x3 conv so its shape matches the residual path.
        shortcut = Conv2D(n_filters, kernel_size=(3, 3), padding='same', strides=(2, 2))(x)
        shortcut = batch_Norm_Activation(shortcut)
        return add([shortcut, res])


    def _res_up_block(x, skip, n_filters):
        """Upsample x by 2, concatenate the encoder skip tensor, then apply a residual block."""
        merged = UpSampling2D((2, 2))(x)
        merged = concatenate([merged, skip])
        res = batch_Norm_Activation(merged)
        res = Conv2D(n_filters, kernel_size=(3, 3), padding='same', strides=(1, 1))(res)
        res = batch_Norm_Activation(res)
        res = Conv2D(n_filters, kernel_size=(3, 3), padding='same', strides=(1, 1))(res)
        # 3x3 projection shortcut maps the concatenated channels to n_filters.
        shortcut = Conv2D(n_filters, kernel_size=(3, 3), padding='same', strides=(1, 1))(merged)
        shortcut = batch_Norm_Activation(shortcut)
        return add([res, shortcut])


    def ResUnet2D(filters, input_height, input_width, input_channels=3):
        """Build a 2D Res-UNet with a single sigmoid output channel (binary segmentation).

        Encoder: a full-resolution residual block (1x1 shortcut) followed by four
        strided residual blocks with filters*2, *4, *8 and *16 channels, each
        downsampling by 2. Bridge: two plain 3x3 convolutions with filters*16
        channels (no shortcut). Decoder: four upsample + concatenate + residual
        blocks with filters*16, *8, *4 and *2 channels that reuse the encoder
        outputs as skip connections.

        Normalization/activation goes through ``batch_Norm_Activation`` with its
        default ``BN`` flag.

        Args:
            filters (int): Base number of convolution filters.
            input_height (int or None): Input height; must be a positive multiple
                of 16. ``None`` is passed to ``Input`` unchecked.
            input_width (int or None): Input width; same constraint.
            input_channels (int, optional): Number of input channels. Defaults to
                3, the value that used to be hard-coded.

        Returns:
            keras.models.Model mapping (input_height, input_width, input_channels)
            inputs to an (input_height, input_width, 1) sigmoid map.

        Raises:
            ValueError: if a given height/width is not a positive multiple of 16.
        """
        # Four stride-2 stages are undone by four 2x upsamplings; any size that is
        # not a multiple of 16 makes a skip connection disagree in shape at
        # `concatenate`, so reject it up front with a clear message.
        for dim_name, dim in (("input_height", input_height), ("input_width", input_width)):
            if dim is not None and (dim <= 0 or dim % 16 != 0):
                raise ValueError(
                    "{} must be a positive multiple of 16, got {!r}".format(dim_name, dim))

        inputs = Input(shape=(input_height, input_width, input_channels))

        # Encoder level 1: full resolution, no pre-activation, 1x1 projection shortcut.
        conv = Conv2D(filters*1, kernel_size=(3, 3), padding='same', strides=(1, 1))(inputs)
        conv = batch_Norm_Activation(conv)
        conv = Conv2D(filters*1, kernel_size=(3, 3), padding='same', strides=(1, 1))(conv)
        shortcut = Conv2D(filters*1, kernel_size=(1, 1), padding='same', strides=(1, 1))(inputs)
        shortcut = batch_Norm_Activation(shortcut)
        output1 = add([conv, shortcut])

        # Encoder levels 2-5, each halving the spatial resolution.
        output2 = _res_down_block(output1, filters*2)
        output3 = _res_down_block(output2, filters*4)
        output4 = _res_down_block(output3, filters*8)
        output5 = _res_down_block(output4, filters*16)

        # Bridge: two convolutions at the lowest resolution, no shortcut.
        conv = batch_Norm_Activation(output5)
        conv = Conv2D(filters*16, kernel_size=(3, 3), padding='same', strides=(1, 1))(conv)
        conv = batch_Norm_Activation(conv)
        conv = Conv2D(filters*16, kernel_size=(3, 3), padding='same', strides=(1, 1))(conv)

        # Decoder: each block doubles the resolution and merges the matching encoder output.
        output6 = _res_up_block(conv, output4, filters*16)
        output7 = _res_up_block(output6, output3, filters*8)
        output8 = _res_up_block(output7, output2, filters*4)
        output9 = _res_up_block(output8, output1, filters*2)

        output_layer = Conv2D(1, (1, 1), padding="same", activation="sigmoid")(output9)
        model = Model(inputs, output_layer)

        return model
    
    # Example: 64 base filters on 512x512 inputs (a multiple of 16, so every skip connection lines up).
    model = ResUnet2D(filters=64, input_height=512, input_width=512)
    model.summary()
    

       网络模型图

  • 相关阅读:
    IP 地址无效化
    上升下降字符串
    STL-----map
    只出现一次的数字
    4的幂
    GDI+_入门教程【一】
    大白话系列之C#委托与事件讲解(二)
    大白话系列之C#委托与事件讲解(二)
    大白话系列之C#委托与事件讲解(一)
    大白话系列之C#委托与事件讲解(一)
  • 原文地址:https://www.cnblogs.com/alex-bn-lee/p/15224922.html
Copyright © 2011-2022 走看看