zoukankan      html  css  js  c++  java
  • 语义分割

    1.vgg_segnet

    from keras.models import *
    from keras.layers import *
    from keras.activations import *
    import keras.backend as K
    import keras
    # Keras data-format string for channels-last (batch, H, W, C) tensors.
    # NOTE(review): the name is misspelled ("ORDERRING"), and none of the layers
    # below pass it as `data_format`. Kept unchanged so existing imports still work.
    IMAGE_ORDERRING = 'channels_last'
    #1.encoder_0.5net
    def convnet_encoder(input_height=416, input_width=416, pretained='imagenet'):
        """Build a VGG16-style convolutional encoder and return its feature taps.

        The encoder stacks five blocks. Each block is a run of 3x3, 'same'-padded
        ReLU convolutions followed by a 2x2 / stride-2 max-pool, so every block
        halves the spatial size (rounding down for odd sizes).

        Args:
            input_height: Height of the input image.
            input_width: Width of the input image.
            pretained: Accepted for API compatibility but currently unused.
                No pretrained weights are loaded by this function.

        Returns:
            A tuple ``(img_input, [f1, f2, f3, f4, f5])``:
            - ``img_input`` is the Keras ``Input`` tensor of shape
              ``(input_height, input_width, 3)``.
            - ``f1``..``f5`` are the outputs of the five pooling layers, at
              1/2, 1/4, 1/8, 1/16 and 1/32 of the input resolution.
        """
        # (filters, number of conv layers) for blocks 1..5. Blocks 1-4 follow
        # VGG16; block 5 uses 1024 filters instead of VGG16's 512.
        block_specs = [(64, 2), (128, 2), (256, 3), (512, 3), (1024, 3)]

        img_input = Input(shape=(input_height, input_width, 3))
        x = img_input
        levels = []
        for block_idx, (filters, n_convs) in enumerate(block_specs, start=1):
            for conv_idx in range(1, n_convs + 1):
                x = Conv2D(filters, (3, 3), activation='relu', padding='same',
                           name='block{}_conv{}'.format(block_idx, conv_idx))(x)
            # Every pool is named 'blockN_pool', including block 1 (previously
            # the inconsistent 'block_pool'). Pool layers have no weights, so
            # by-name weight loading is not affected.
            x = MaxPooling2D((2, 2), strides=(2, 2),
                             name='block{}_pool'.format(block_idx))(x)
            levels.append(x)
        return img_input, levels
    

      

    #2.decoder_0.5net
    def segnet_decoder(output_input, n_classes):
        """SegNet-style decoder head.

        Runs four stages with 512, 256, 128 and 64 filters. Each stage is
        zero-pad(1) -> 3x3 'valid' conv -> batch-norm, and a 2x upsampling is
        inserted before each of the last three stages. A final 3x3 'same' conv
        then projects to ``n_classes`` channels.

        Net effect: the spatial size is multiplied by 8 and the channel count
        becomes ``n_classes``.

        NOTE(review): no activation is applied after any convolution or
        batch-norm here, so the stages appear to contain no non-linearity.
        Confirm this is intended.
        """
        stage_filters = (512, 256, 128, 64)
        feat = output_input
        for stage, filters in enumerate(stage_filters):
            if stage > 0:
                # Upsample between stages, but not before the first one.
                feat = UpSampling2D((2, 2))(feat)
            # A 1-pixel zero pad plus a 'valid' 3x3 conv keeps the spatial size.
            feat = ZeroPadding2D((1, 1))(feat)
            feat = Conv2D(filters, (3, 3), padding='valid')(feat)
            feat = BatchNormalization()(feat)
        # Raw per-pixel class scores; no activation is applied in this function.
        return Conv2D(n_classes, (3, 3), padding='same')(feat)
    

      

    #3.vgg_segnet_net
    def convnet_segnet(n_classes, input_height=416, input_width=416):
        """Assemble the encoder and the SegNet decoder into a per-pixel classifier.

        The decoder receives the encoder's fourth feature level (f4, at 1/16 of
        the input resolution) and upsamples it x8. The prediction map is
        therefore about half the input resolution.

        Args:
            n_classes: Number of output classes.
            input_height: Input image height.
            input_width: Input image width.

        Returns:
            A Keras ``Model`` that maps ``(input_height, input_width, 3)``
            images to an ``(out_h * out_w, n_classes)`` tensor of per-pixel
            class probabilities.
            - ``out_h`` x ``out_w`` is the decoder's output size.
            - That size equals ``(input_height // 2, input_width // 2)`` when
              both dimensions are divisible by 16.
        """
        img_input, levels = convnet_encoder(input_height=input_height,
                                            input_width=input_width)
        feat = levels[3]  # f4
        x = segnet_decoder(feat, n_classes)
        # Flatten the spatial grid so each row holds one pixel's class scores.
        # Read the decoder's real output size rather than recomputing input/2.
        # Four floor-halving pools followed by x8 upsampling give input/2 only
        # when the input is divisible by 16; for any other size the old
        # formula made Reshape fail.
        out_h, out_w = K.int_shape(x)[1:3]
        x = Reshape((out_h * out_w, -1))(x)  # (pixels, class scores)
        out = Softmax()(x)  # softmax over the class axis, per pixel
        model = Model(img_input, out)
        model.model_name = 'convnet_segnet'
        return model
    # Build a 2-class model at the default 416x416 input size and print its
    # layer-by-layer summary.
    model = convnet_segnet(n_classes=2, input_height=416, input_width=416)
    model.summary()
    

      

  • 相关阅读:
    常用的算法
    2017前端面试题
    深入了解php opcode缓存原理
    0=='aa'的结果是true
    关于PHP浮点数之 intval((0.1+0.7)*10) 为什么是7
    linux grep命令
    linux awk命令详解
    PHP socket模拟POST请求
    shell编程之sed
    Shell脚本常用判断
  • 原文地址:https://www.cnblogs.com/Turing-dz/p/13331104.html
Copyright © 2011-2022 走看看