zoukankan      html  css  js  c++  java
  • keras实现MobileNet

    利用keras实现MobileNet,并以mnist数据集作为一个小例子进行识别。使用的环境是:tensorflow-gpu 2.0,python=3.7 , GTX-2070的GPU

    1.导入数据

    • 首先是导入两行魔法命令,可以多行显示.
    %config InteractiveShell.ast_node_interactivity="all"
    %pprint
    
    • 加载keras中自带的mnist数据
    import tensorflow as tf
    import keras 
    
    tf.debugging.set_log_device_placement(True)
    
    mnist = keras.datasets.mnist
    
    (x_train,y_train),(x_test,y_test) = mnist.load_data()
    

    上述tf.debugging.set_log_device_placement(True)的作用是将模型放在GPU上进行训练。

    • 数据的转换
      在mnist上下载的数据的分辨率是2828的,mobilenet用来训练的数据是ImageNet ,其图片的分辨率是224224,所以先将图片的维度调整为224*224.
    from PIL import Image
    import numpy as np
    def convert_mnist_224pix(X):
        img=Image.fromarray(X)
        x=np.zeros((224,224))
        img=np.array(img.resize((224,224)))
        x[:,:]=img
        
        return x
    
    iteration = iter(x_train)
    new_train =np.zeros((len(x_train),224,224),dtype=np.float32)
    for i in range(len(x_train)):
        data = next(iteration)
        new_train[i]=convert_mnist_224pix(data)
        
        if i%5000==0:
            print(i)
        
    
    new_train.shape
    

    这里要注意一下,new_train中一定要注明dtype=np.float32,不然默认的是float64,这样数据就太大了,没有那么多存储空间装。最后输出的维度是(60000,224,224)

    2.搭建模型

    • 导入所有需要的函数和库
    from keras.layers import Conv2D,DepthwiseConv2D,Dense,AveragePooling2D,BatchNormalization,Input
    from keras import Model
    from keras import Sequential
    from keras.layers.advanced_activations import ReLU
    from keras.utils import to_categorical
    
    • 自己定义中间可以重复利用的层,将其放在一起,简化搭建网络的重复代码。
    def depth_point_conv2d(x,s=[1,1,2,1],channel=[64,128]):
        """
        s:the strides of the conv
        channel: the depth of pointwiseconvolutions
        """
        
        dw1 = DepthwiseConv2D((3,3),strides=s[0],padding='same')(x)
        bn1 = BatchNormalization()(dw1)
        relu1 = ReLU()(bn1)
        pw1 = Conv2D(channel[0],(1,1),strides=s[1],padding='same')(relu1)
        bn2 = BatchNormalization()(pw1)
        relu2 = ReLU()(bn2)
        dw2 = DepthwiseConv2D((3,3),strides=s[2],padding='same')(relu2)
        bn3 = BatchNormalization()(dw2)
        relu3 = ReLU()(bn3)
        pw2 = Conv2D(channel[1],(1,1),strides=s[3],padding='same')(relu3)
        bn4 = BatchNormalization()(pw2)
        relu4 = ReLU()(bn4)
        
        return relu4
        
    def repeat_conv(x,s=[1,1],channel=512):
        dw1 = DepthwiseConv2D((3,3),strides=s[0],padding='same')(x)
        bn1 = BatchNormalization()(dw1)
        relu1 = ReLU()(bn1)
        pw1 = Conv2D(channel,(1,1),strides=s[1],padding='same')(relu1)
        bn2 = BatchNormalization()(pw1)
        relu2 = ReLU()(bn2)
        
        return relu2
        
    

    根据mobilenet论文中的结构进行模型的搭建
    MobileNet在倒数第5行Conv/dw/s2中,我一直不理解如果strides=2,为什么最后生成图片尺寸没有变化,我感觉可能是笔误?,不过我这里将strides定义为1,因为这样才符合后面的整个输出。

    • 搭建网络
    h0=Input(shape=(224,224,1))
    h1=Conv2D(32,(3,3),strides = 2,padding="same")(h0)
    h2= BatchNormalization()(h1)
    h3=ReLU()(h2)
    h4 = depth_point_conv2d(h3,s=[1,1,2,1],channel=[64,128])
    h5 = depth_point_conv2d(h4,s=[1,1,2,1],channel=[128,256])
    h6 = depth_point_conv2d(h5,s=[1,1,2,1],channel=[256,512])
    h7 = repeat_conv(h6)
    h8 = repeat_conv(h7)
    h9 = repeat_conv(h8)
    h10 = repeat_conv(h9)
    h11 = depth_point_conv2d(h10,s=[1,1,2,1],channel=[512,1024])
    h12 = repeat_conv(h11,channel=1024)
    h13 = AveragePooling2D((7,7))(h12)
    h14 = Dense(10,activation='softmax')(h13)
    model =Model(input=h0,output =h14)
    model.summary()
    
    Model: "model_4"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_11 (InputLayer)        (None, 224, 224, 1)       0         
    _________________________________________________________________
    conv2d_63 (Conv2D)           (None, 112, 112, 32)      320       
    _________________________________________________________________
    batch_normalization_120 (Bat (None, 112, 112, 32)      128       
    _________________________________________________________________
    re_lu_120 (ReLU)             (None, 112, 112, 32)      0         
    _________________________________________________________________
    depthwise_conv2d_58 (Depthwi (None, 112, 112, 32)      320       
    _________________________________________________________________
    batch_normalization_121 (Bat (None, 112, 112, 32)      128       
    _________________________________________________________________
    re_lu_121 (ReLU)             (None, 112, 112, 32)      0         
    _________________________________________________________________
    conv2d_64 (Conv2D)           (None, 112, 112, 64)      2112      
    _________________________________________________________________
    batch_normalization_122 (Bat (None, 112, 112, 64)      256       
    _________________________________________________________________
    re_lu_122 (ReLU)             (None, 112, 112, 64)      0         
    _________________________________________________________________
    depthwise_conv2d_59 (Depthwi (None, 56, 56, 64)        640       
    _________________________________________________________________
    batch_normalization_123 (Bat (None, 56, 56, 64)        256       
    _________________________________________________________________
    re_lu_123 (ReLU)             (None, 56, 56, 64)        0         
    _________________________________________________________________
    conv2d_65 (Conv2D)           (None, 56, 56, 128)       8320      
    _________________________________________________________________
    batch_normalization_124 (Bat (None, 56, 56, 128)       512       
    _________________________________________________________________
    re_lu_124 (ReLU)             (None, 56, 56, 128)       0         
    _________________________________________________________________
    depthwise_conv2d_60 (Depthwi (None, 56, 56, 128)       1280      
    _________________________________________________________________
    batch_normalization_125 (Bat (None, 56, 56, 128)       512       
    _________________________________________________________________
    re_lu_125 (ReLU)             (None, 56, 56, 128)       0         
    _________________________________________________________________
    conv2d_66 (Conv2D)           (None, 56, 56, 128)       16512     
    _________________________________________________________________
    batch_normalization_126 (Bat (None, 56, 56, 128)       512       
    _________________________________________________________________
    re_lu_126 (ReLU)             (None, 56, 56, 128)       0         
    _________________________________________________________________
    depthwise_conv2d_61 (Depthwi (None, 28, 28, 128)       1280      
    _________________________________________________________________
    batch_normalization_127 (Bat (None, 28, 28, 128)       512       
    _________________________________________________________________
    re_lu_127 (ReLU)             (None, 28, 28, 128)       0         
    _________________________________________________________________
    conv2d_67 (Conv2D)           (None, 28, 28, 256)       33024     
    _________________________________________________________________
    batch_normalization_128 (Bat (None, 28, 28, 256)       1024      
    _________________________________________________________________
    re_lu_128 (ReLU)             (None, 28, 28, 256)       0         
    _________________________________________________________________
    depthwise_conv2d_62 (Depthwi (None, 28, 28, 256)       2560      
    _________________________________________________________________
    batch_normalization_129 (Bat (None, 28, 28, 256)       1024      
    _________________________________________________________________
    re_lu_129 (ReLU)             (None, 28, 28, 256)       0         
    _________________________________________________________________
    conv2d_68 (Conv2D)           (None, 28, 28, 256)       65792     
    _________________________________________________________________
    batch_normalization_130 (Bat (None, 28, 28, 256)       1024      
    _________________________________________________________________
    re_lu_130 (ReLU)             (None, 28, 28, 256)       0         
    _________________________________________________________________
    depthwise_conv2d_63 (Depthwi (None, 14, 14, 256)       2560      
    _________________________________________________________________
    batch_normalization_131 (Bat (None, 14, 14, 256)       1024      
    _________________________________________________________________
    re_lu_131 (ReLU)             (None, 14, 14, 256)       0         
    _________________________________________________________________
    conv2d_69 (Conv2D)           (None, 14, 14, 512)       131584    
    _________________________________________________________________
    batch_normalization_132 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_132 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    depthwise_conv2d_64 (Depthwi (None, 14, 14, 512)       5120      
    _________________________________________________________________
    batch_normalization_133 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_133 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    conv2d_70 (Conv2D)           (None, 14, 14, 512)       262656    
    _________________________________________________________________
    batch_normalization_134 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_134 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    depthwise_conv2d_65 (Depthwi (None, 14, 14, 512)       5120      
    _________________________________________________________________
    batch_normalization_135 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_135 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    conv2d_71 (Conv2D)           (None, 14, 14, 512)       262656    
    _________________________________________________________________
    batch_normalization_136 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_136 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    depthwise_conv2d_66 (Depthwi (None, 14, 14, 512)       5120      
    _________________________________________________________________
    batch_normalization_137 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_137 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    conv2d_72 (Conv2D)           (None, 14, 14, 512)       262656    
    _________________________________________________________________
    batch_normalization_138 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_138 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    depthwise_conv2d_67 (Depthwi (None, 14, 14, 512)       5120      
    _________________________________________________________________
    batch_normalization_139 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_139 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    conv2d_73 (Conv2D)           (None, 14, 14, 512)       262656    
    _________________________________________________________________
    batch_normalization_140 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_140 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    depthwise_conv2d_68 (Depthwi (None, 14, 14, 512)       5120      
    _________________________________________________________________
    batch_normalization_141 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_141 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    conv2d_74 (Conv2D)           (None, 14, 14, 512)       262656    
    _________________________________________________________________
    batch_normalization_142 (Bat (None, 14, 14, 512)       2048      
    _________________________________________________________________
    re_lu_142 (ReLU)             (None, 14, 14, 512)       0         
    _________________________________________________________________
    depthwise_conv2d_69 (Depthwi (None, 7, 7, 512)         5120      
    _________________________________________________________________
    batch_normalization_143 (Bat (None, 7, 7, 512)         2048      
    _________________________________________________________________
    re_lu_143 (ReLU)             (None, 7, 7, 512)         0         
    _________________________________________________________________
    conv2d_75 (Conv2D)           (None, 7, 7, 1024)        525312    
    _________________________________________________________________
    batch_normalization_144 (Bat (None, 7, 7, 1024)        4096      
    _________________________________________________________________
    re_lu_144 (ReLU)             (None, 7, 7, 1024)        0         
    _________________________________________________________________
    depthwise_conv2d_70 (Depthwi (None, 7, 7, 1024)        10240     
    _________________________________________________________________
    batch_normalization_145 (Bat (None, 7, 7, 1024)        4096      
    _________________________________________________________________
    re_lu_145 (ReLU)             (None, 7, 7, 1024)        0         
    _________________________________________________________________
    conv2d_76 (Conv2D)           (None, 7, 7, 1024)        1049600   
    _________________________________________________________________
    batch_normalization_146 (Bat (None, 7, 7, 1024)        4096      
    _________________________________________________________________
    re_lu_146 (ReLU)             (None, 7, 7, 1024)        0         
    _________________________________________________________________
    average_pooling2d_5 (Average (None, 1, 1, 1024)        0         
    _________________________________________________________________
    dense_4 (Dense)              (None, 1, 1, 10)          10250     
    =================================================================
    Total params: 3,249,482
    Trainable params: 3,227,594
    Non-trainable params: 21,888
    _________________________________________________________________
    

    因为这里的类别只有10类,所以最后的输出层只有10个神经元,原始的mobilenet要进行1000个类别分类,所以最后是1000个神经元。

    model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
    

    上述代码定义优化算法和损失函数。

    3、训练数据的整理与训练

    将训练数据进行维度变换,标签进行one-hot编码并进行维度变换。

    x_train = np.expand_dims(new_train,3)
    
    y_train = to_categorical(y_train)
    
    y=np.expand_dims(y_train,1)
    y = np.expand_dims(y,1)
    
    • 定义数据生成函数
    def data_generate(x_train,y_train,batch_size,epochs):
        for i in range(epochs):
            batch_num = len(x_train)//batch_size
            shuffle_index = np.arange(batch_num)
            np.random.shuffle(shuffle_index)
            for j in shuffle_index:
                begin = j*batch_size
                end =begin+batch_size
                x = x_train[begin:end]
                y = y_train[begin:end]
                
                yield ({"input_11":x},{"dense_4":y})
                
    

    上述命名和model中的第一层和最后一层名字一样,不然会报错。

    • 开始训练
    model.fit_generator(data_generate(x_train,y,100,11),step_per_epoch=600,epochs=10)
    

    训练过程图如下:

    Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0
    Epoch 1/10
    Executing op __inference_keras_scratch_graph_22639 in device /job:localhost/replica:0/task:0/device:GPU:0
    600/600 [==============================] - 411s 684ms/step - loss: 0.1469 - accuracy: 0.9529
    Epoch 2/10
    600/600 [==============================] - 398s 663ms/step - loss: 0.0375 - accuracy: 0.9884
    Epoch 3/10
    600/600 [==============================] - 401s 668ms/step - loss: 0.0283 - accuracy: 0.9909
    Epoch 4/10
    600/600 [==============================] - 399s 665ms/step - loss: 0.0211 - accuracy: 0.9936
    Epoch 5/10
    600/600 [==============================] - 400s 666ms/step - loss: 0.0216 - accuracy: 0.9932
    Epoch 6/10
    600/600 [==============================] - 401s 668ms/step - loss: 0.0208 - accuracy: 0.9935
    Epoch 7/10
    600/600 [==============================] - 401s 669ms/step - loss: 0.0174 - accuracy: 0.9945
    Epoch 8/10
    131/600 [=====>........................] - ETA: 5:13 - loss: 0.0091 - accuracy: 0.9973
    ​
    

    模型卷积比较多,需要训练的时间有点长,参数不多,所以更新较快,收敛速度也很快。

  • 相关阅读:
    使用Python操作InfluxDB时序数据库
    LogMysqlApeT
    内建函数 iter()
    Python魔法方法总结及注意事项
    Python魔法方法之属性访问 ( __getattr__, __getattribute__, __setattr__, __delattr__ )
    Python描述符 (descriptor) 详解
    在命令行模式下查看Python帮助文档---dir、help、__doc__
    python高并发的解决方案
    map中的erase成员函数用法
    指针的本质
  • 原文地址:https://www.cnblogs.com/zhou-lin/p/14047409.html
Copyright © 2011-2022 走看看