  • Implementing MobileNet in Keras

    Using Keras to implement MobileNet, with the MNIST dataset as a small recognition example. Environment: tensorflow-gpu 2.0, python=3.7, and a GTX-2070 GPU.

    1. Loading the data

    • First, two magic commands so that a Jupyter cell displays the output of every expression.
    %config InteractiveShell.ast_node_interactivity="all"
    %pprint
    
    • Load the MNIST data that ships with Keras.
    import tensorflow as tf
    import keras 
    
    tf.debugging.set_log_device_placement(True)
    
    mnist = keras.datasets.mnist
    
    (x_train,y_train),(x_test,y_test) = mnist.load_data()
    

    The call tf.debugging.set_log_device_placement(True) above logs which device each operation is placed on, which makes it easy to confirm that the model is actually being trained on the GPU.
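
    If you want to check GPU availability explicitly before training, here is a minimal sketch (assuming the TF 2.0 API, where the experimental namespace is still needed):

    # List the physical GPUs that TensorFlow can see.
    gpus = tf.config.experimental.list_physical_devices('GPU')
    print("GPUs:", gpus)
    # Optionally let TensorFlow grow GPU memory on demand instead of reserving it all up front.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)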

    • Converting the data
      The MNIST images are 28×28, while MobileNet was designed for ImageNet, whose images are 224×224, so the images are first resized to 224×224.
    from PIL import Image
    import numpy as np
    def convert_mnist_224pix(X):
        """Resize a single 28x28 MNIST image to 224x224."""
        img = Image.fromarray(X)
        # PIL's resize returns a 224x224 image; convert it back to a NumPy array.
        x = np.array(img.resize((224, 224)))
        
        return x
    
    # Resize every training image; pre-allocate as float32 (float64 would double the memory use).
    new_train = np.zeros((len(x_train), 224, 224), dtype=np.float32)
    for i, data in enumerate(x_train):
        new_train[i] = convert_mnist_224pix(data)
        
        # Print progress every 5000 images.
        if i % 5000 == 0:
            print(i)
        
    
    new_train.shape
    

    Note that new_train must be created with dtype=np.float32; the NumPy default is float64, which doubles the memory footprint and will not fit in memory at this scale. The final shape is (60000, 224, 224).
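
    A quick back-of-the-envelope calculation (a sketch, not from the original post) shows why the dtype matters:

    n_values = 60000 * 224 * 224      # about 3.0 billion values
    print(n_values * 4 / 1024**3)     # float32: ~11.2 GiB
    print(n_values * 8 / 1024**3)     # float64: ~22.4 GiB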

    2. Building the model

    • Import all the layers and utilities we need.
    from keras.layers import Conv2D,DepthwiseConv2D,Dense,AveragePooling2D,BatchNormalization,Input
    from keras import Model
    from keras import Sequential
    from keras.layers.advanced_activations import ReLU
    from keras.utils import to_categorical
    
    • Define reusable building blocks for the repeated middle layers, which keeps the network-building code short.
    def depth_point_conv2d(x,s=[1,1,2,1],channel=[64,128]):
        """
        s:the strides of the conv
        channel: the depth of pointwiseconvolutions
        """
        
        dw1 = DepthwiseConv2D((3,3),strides=s[0],padding='same')(x)
        bn1 = BatchNormalization()(dw1)
        relu1 = ReLU()(bn1)
        pw1 = Conv2D(channel[0],(1,1),strides=s[1],padding='same')(relu1)
        bn2 = BatchNormalization()(pw1)
        relu2 = ReLU()(bn2)
        dw2 = DepthwiseConv2D((3,3),strides=s[2],padding='same')(relu2)
        bn3 = BatchNormalization()(dw2)
        relu3 = ReLU()(bn3)
        pw2 = Conv2D(channel[1],(1,1),strides=s[3],padding='same')(relu3)
        bn4 = BatchNormalization()(pw2)
        relu4 = ReLU()(bn4)
        
        return relu4
        
    def repeat_conv(x,s=[1,1],channel=512):
        dw1 = DepthwiseConv2D((3,3),strides=s[0],padding='same')(x)
        bn1 = BatchNormalization()(dw1)
        relu1 = ReLU()(bn1)
        pw1 = Conv2D(channel,(1,1),strides=s[1],padding='same')(relu1)
        bn2 = BatchNormalization()(pw1)
        relu2 = ReLU()(bn2)
        
        return relu2
        
    

    The model is built following the architecture table in the MobileNet paper.
    In the fifth row from the bottom of that table, a Conv dw / s2 layer is listed with stride 2 even though the spatial size of the feature map does not shrink afterwards; I never understood this and suspect it is a typo, so I set strides=1 here, because that is what makes the rest of the output shapes consistent.
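
    To see why strides=1 is the consistent choice there, here is a minimal shape check (a sketch; t is a throwaway input tensor): with padding='same', the output size is ceil(input_size / stride), so a stride of 2 on a 7×7 feature map would give 4×4, not 7×7.

    from keras import backend as K

    t = Input(shape=(7, 7, 1024))
    print(K.int_shape(DepthwiseConv2D((3, 3), strides=1, padding='same')(t)))  # (None, 7, 7, 1024)
    print(K.int_shape(DepthwiseConv2D((3, 3), strides=2, padding='same')(t)))  # (None, 4, 4, 1024)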

    • Build the network.
    h0=Input(shape=(224,224,1))
    h1=Conv2D(32,(3,3),strides = 2,padding="same")(h0)
    h2= BatchNormalization()(h1)
    h3=ReLU()(h2)
    h4 = depth_point_conv2d(h3,s=[1,1,2,1],channel=[64,128])
    h5 = depth_point_conv2d(h4,s=[1,1,2,1],channel=[128,256])
    h6 = depth_point_conv2d(h5,s=[1,1,2,1],channel=[256,512])
    h7 = repeat_conv(h6)
    h8 = repeat_conv(h7)
    h9 = repeat_conv(h8)
    h10 = repeat_conv(h9)
    h11 = depth_point_conv2d(h10,s=[1,1,2,1],channel=[512,1024])
    h12 = repeat_conv(h11,channel=1024)
    h13 = AveragePooling2D((7,7))(h12)
    h14 = Dense(10,activation='softmax')(h13)
    model = Model(inputs=h0, outputs=h14)
    model.summary()
    
    Model: "model_4"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_11 (InputLayer)        (None, 224, 224, 1)       0         
    _________________________________________________________________
    conv2d_63 (Conv2D)           (None, 112, 112, 32)      320       
    _________________________________________________________________
    batch_normalization_120 (Bat (None, 112, 112, 32)      128       
    _________________________________________________________________
    re_lu_120 (ReLU)             (None, 112, 112, 32)      0         
    _________________________________________________________________
    depthwise_conv2d_58 (Depthwi (None, 112, 112, 32)      320       
    _________________________________________________________________
    batch_normalization_121 (Bat (None, 112, 112, 32)      128       
    _________________________________________________________________
    re_lu_121 (ReLU)             (None, 112, 112, 32)      0         
    _________________________________________________________________
    conv2d_64 (Conv2D)           (None, 112, 112, 64)      2112      
    _________________________________________________________________
    batch_normalization_122 (Bat (None, 112, 112, 64)      256       
    _________________________________________________________________
    re_lu_122 (ReLU)             (None, 112, 112, 64)      0         
    _________________________________________________________________
    depthwise_conv2d_59 (Depthwi (None, 56, 56, 64)        640       
    _________________________________________________________________
    batch_normalization_123 (Bat (None, 56, 56, 64)        256       
    _________________________________________________________________
    re_lu_123 (ReLU)             (None, 56, 56, 64)        0         
    _________________________________________________________________
    conv2d_65 (Conv2D)           (None, 56, 56, 128)       8320      
    _________________________________________________________________
    batch_normalization_124 (Bat (None, 56, 56, 128)       512       
    _________________________________________________________________
    re_lu_124 (ReLU)             (None, 56, 56, 128)       0         
    _________________________________________________________________
    depthwise_conv2d_60 (Depthwi (None, 56, 56, 128)       1280      
    _________________________________________________________________
    batch_normalization_125 (Bat (None, 56, 56, 128)       512       
    _________________________________________________________________
    re_lu_125 (ReLU)             (None, 56, 56, 128)       0         
    _________________________________________________________________
    conv2d_66 (Conv2D)           (None, 56, 56, 128)       16512     
    _________________________________________________________________
    batch_normalization_126 (Bat (None, 56, 56, 128)       512       
    _________________________________________________________________
    re_lu_126 (ReLU)             (None, 56, 56, 128)       0         
    _________________________________________________________________
    depthwise_conv2d_61 (Depthwi (None, 28, 28, 128)       1280      
    _________________________________________________________________
    batch_normalization_127 (Bat (None, 28, 28, 128)       512       
    _________________________________________________________________
    re_lu_127 (ReLU)             (None, 28, 28, 128)       0         
    _________________________________________________________________
    conv2d_67 (Conv2D)           (None, 28, 28, 256)       33024     
    _________________________________________________________________
    batch_normalization_128 (Bat (None, 28, 28, 256)       1024      
    _________________________________________________________________
    re_lu_128 (ReLU)             (None, 28, 28, 256)       0         
    _________________________________________________________________
    depthwise_conv2d_62 (Depthwi (None, 28, 28, 256)       2560      
    _________________________________________________________________
    batch_normalization_129 (Bat (None, 28, 28, 256)       1024      
    _________________________________________________________________
    re_lu_129 (ReLU)             (None, 28, 28, 256)       0         
    _________________________________________________________________
    conv2d_68 (Conv2D)           (None, 28, 28, 256)       65792     
    _________________________________________________________________
    batch_normalization_130 (Bat (None, 28, 28, 256)       1024      
    _________________________________________________________________
    re_lu_130 (ReLU)             (None, 28, 28, 256)       0         
    _________________________________________________________________
    depthwise_conv2d_63 (Depthwi (None, 14, 14, 256)       2560      
    _________________________________________________________________
    batch_normalization_131 (Bat (None, 14, 14, 256)       1024      
    _________________________________________________________________
    re_lu_131 (ReLU)             (None, 14, 14, 256)       0         
    _________________________________________________________________
    conv2d_69 (Conv2D)           (None, 14, 14, 512)       131584    
    _________________________________________________________________
    batch_normalization_132 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_132 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    depthwise_conv2d_64 (Depthwi (None, 14, 14, 512)       5120      
    _________________________________________________________________
    batch_normalization_133 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_133 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    conv2d_70 (Conv2D)           (None, 14, 14, 512)       262656    
    _________________________________________________________________
    batch_normalization_134 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_134 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    depthwise_conv2d_65 (Depthwi (None, 14, 14, 512)       5120      
    _________________________________________________________________
    batch_normalization_135 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_135 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    conv2d_71 (Conv2D)           (None, 14, 14, 512)       262656    
    _________________________________________________________________
    batch_normalization_136 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_136 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    depthwise_conv2d_66 (Depthwi (None, 14, 14, 512)       5120      
    _________________________________________________________________
    batch_normalization_137 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_137 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    conv2d_72 (Conv2D)           (None, 14, 14, 512)       262656    
    _________________________________________________________________
    batch_normalization_138 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_138 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    depthwise_conv2d_67 (Depthwi (None, 14, 14, 512)       5120      
    _________________________________________________________________
    batch_normalization_139 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_139 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    conv2d_73 (Conv2D)           (None, 14, 14, 512)       262656    
    _________________________________________________________________
    batch_normalization_140 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_140 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    depthwise_conv2d_68 (Depthwi (None, 14, 14, 512)       5120      
    _________________________________________________________________
    batch_normalization_141 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_141 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    conv2d_74 (Conv2D)           (None, 14, 14, 512)       262656    
    _________________________________________________________________
    batch_normalization_142 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_142 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    depthwise_conv2d_69 (Depthwi (None, 7, 7, 512)         5120      
    _________________________________________________________________
    batch_normalization_143 (Bat (None, 7, 7, 512)         2048      
    _________________________________________________________________
    re_lu_143 (ReLU)             (None, 7, 7, 512)         0         
    _________________________________________________________________
    conv2d_75 (Conv2D)           (None, 7, 7, 1024)        525312    
    _________________________________________________________________
    batch_normalization_144 (Bat (None, 7, 7, 1024)        4096      
    _________________________________________________________________
    re_lu_144 (ReLU)             (None, 7, 7, 1024)        0         
    _________________________________________________________________
    depthwise_conv2d_70 (Depthwi (None, 7, 7, 1024)        10240     
    _________________________________________________________________
    batch_normalization_145 (Bat (None, 7, 7, 1024)        4096      
    _________________________________________________________________
    re_lu_145 (ReLU)             (None, 7, 7, 1024)        0         
    _________________________________________________________________
    conv2d_76 (Conv2D)           (None, 7, 7, 1024)        1049600   
    _________________________________________________________________
    batch_normalization_146 (Bat (None, 7, 7, 1024)        4096      
    _________________________________________________________________
    re_lu_146 (ReLU)             (None, 7, 7, 1024)        0         
    _________________________________________________________________
    average_pooling2d_5 (Average (None, 1, 1, 1024)        0         
    _________________________________________________________________
    dense_4 (Dense)              (None, 1, 1, 10)          10250     
    =================================================================
    Total params: 3,249,482
    Trainable params: 3,227,594
    Non-trainable params: 21,888
    _________________________________________________________________
    

    Because there are only 10 classes here, the final output layer has 10 neurons; the original MobileNet classifies the 1000 ImageNet categories, so its last layer has 1000 neurons.

    model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
    

    The line above sets the optimizer and the loss function.
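
    To control the learning rate explicitly, an equivalent compile call with an optimizer object looks like this (a sketch; 0.001 is simply Adam's default rate):

    from keras.optimizers import Adam

    model.compile(optimizer=Adam(lr=0.001),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])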

    3. Preparing the training data and training

    Add a channel dimension to the training data, then one-hot encode the labels and expand their dimensions so they match the model's (None, 1, 1, 10) output.

    x_train = np.expand_dims(new_train,3)
    
    y_train = to_categorical(y_train)
    
    y=np.expand_dims(y_train,1)
    y = np.expand_dims(y,1)
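
    A quick sanity check of the resulting shapes (a sketch; the extra singleton axes on y are needed only because the final Dense layer is applied to the 4D pooling output):

    print(x_train.shape)  # (60000, 224, 224, 1)
    print(y.shape)        # (60000, 1, 1, 10)

    An alternative design would be a Flatten layer before the final Dense, which would let the labels stay (60000, 10).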
    
    • Define the data generator.
    def data_generate(x_train,y_train,batch_size,epochs):
        for i in range(epochs):
            batch_num = len(x_train)//batch_size
            # Shuffle the order of the batches each epoch (the samples inside a batch stay fixed).
            shuffle_index = np.arange(batch_num)
            np.random.shuffle(shuffle_index)
            for j in shuffle_index:
                begin = j*batch_size
                end = begin+batch_size
                x = x_train[begin:end]
                y = y_train[begin:end]
                
                # The dict keys must match the model's input and output layer names.
                yield ({"input_11":x},{"dense_4":y})
                
    

    The dict keys above must match the names of the model's first (input) layer and last (output) layer, otherwise fit_generator raises an error.
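
    Those names are auto-generated by Keras and change every time the model is rebuilt, so it is safer to read them from the model instead of hard-coding them (a minimal sketch):

    print(model.input_names)   # e.g. ['input_11']
    print(model.output_names)  # e.g. ['dense_4']

    Alternatively, the generator can yield plain (x, y) tuples, which fit_generator also accepts.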

    • Start training.
    model.fit_generator(data_generate(x_train,y,100,11),steps_per_epoch=600,epochs=10)
    

    The training log looks like this:

    Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0
    Epoch 1/10
    Executing op __inference_keras_scratch_graph_22639 in device /job:localhost/replica:0/task:0/device:GPU:0
    600/600 [==============================] - 411s 684ms/step - loss: 0.1469 - accuracy: 0.9529
    Epoch 2/10
    600/600 [==============================] - 398s 663ms/step - loss: 0.0375 - accuracy: 0.9884
    Epoch 3/10
    600/600 [==============================] - 401s 668ms/step - loss: 0.0283 - accuracy: 0.9909
    Epoch 4/10
    600/600 [==============================] - 399s 665ms/step - loss: 0.0211 - accuracy: 0.9936
    Epoch 5/10
    600/600 [==============================] - 400s 666ms/step - loss: 0.0216 - accuracy: 0.9932
    Epoch 6/10
    600/600 [==============================] - 401s 668ms/step - loss: 0.0208 - accuracy: 0.9935
    Epoch 7/10
    600/600 [==============================] - 401s 669ms/step - loss: 0.0174 - accuracy: 0.9945
    Epoch 8/10
    131/600 [=====>........................] - ETA: 5:13 - loss: 0.0091 - accuracy: 0.9973
    

    The model has quite a few convolution layers, so each epoch takes a while to train, but the parameter count is small, so the updates are cheap and the network converges quickly.
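
    The test set loaded at the beginning is never used above; evaluating on it requires the same preprocessing as the training data. A sketch reusing the helpers defined earlier (this evaluation step is not part of the original post):

    # Resize and reshape the 10000 test images exactly like the training images.
    new_test = np.zeros((len(x_test), 224, 224), dtype=np.float32)
    for i, img in enumerate(x_test):
        new_test[i] = convert_mnist_224pix(img)

    x_test_in = np.expand_dims(new_test, 3)
    y_test_in = np.expand_dims(np.expand_dims(to_categorical(y_test), 1), 1)

    loss, acc = model.evaluate(x_test_in, y_test_in, batch_size=100)
    print(loss, acc)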
