    Hands-on transfer learning: VGG19, ResNet50, and InceptionV3 on the Dogs vs. Cats problem

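    I. Workflow

    1. Data preprocessing

    The training images are augmented with random shifts, rotations, and similar transforms, so that the training data is as varied as possible.

    The data is also read in shuffled batches, which avoids training on images in directory order.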

        from keras.preprocessing.image import ImageDataGenerator

        # Data preparation
        def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):
            if is_train:
                # Training data: rescale plus random zoom, rotation, channel
                # shift, small translations, and horizontal flips.
                datagen = ImageDataGenerator(rescale=1./255,
                    zoom_range=0.25, rotation_range=15.,
                    channel_shift_range=25., width_shift_range=0.02, height_shift_range=0.02,
                    horizontal_flip=True, fill_mode='constant')
            else:
                # Validation data: rescale only.
                datagen = ImageDataGenerator(rescale=1./255)

            generator = datagen.flow_from_directory(
                dir_path, target_size=(img_row, img_col),
                batch_size=batch_size,
                shuffle=is_train)

            return generator

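
To make the shuffled-batch reading concrete, here is a minimal sketch in plain Python, independent of Keras. The helper name `shuffled_batches` and the file names are made up for illustration. `flow_from_directory` does much more than this (it also loads and resizes the images), so the sketch only shows the ordering behaviour that `shuffle=is_train` switches on.

```python
import random


def shuffled_batches(items, batch_size, shuffle=True, seed=None):
    """Yield batches that together cover every item exactly once."""
    order = list(items)
    if shuffle:
        # Shuffle a copy so the caller's sequence is left untouched.
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield order[start:start + batch_size]


# Hypothetical file names, listed directory by directory (all cats, then all dogs).
files = [f"cat/{i}.jpg" for i in range(5)] + [f"dog/{i}.jpg" for i in range(5)]
batches = list(shuffled_batches(files, batch_size=4, seed=0))
for batch in batches:
    print(batch)
```

With `shuffle=False` the first batch would contain only cat images, which is the directory-ordered training the post wants to avoid.
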
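    2. Loading a pretrained model

    This is the core step: the weights trained on ImageNet serve as our feature extractor. Note that the classification layers at the end of the network are dropped (include_top=False).

        from keras.applications.inception_v3 import InceptionV3

        # classes is ignored when include_top=False
        base_model = InceptionV3(weights='imagenet', include_top=False, pooling=None,
                                 input_shape=(img_rows, img_cols, color),
                                 classes=nb_classes)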

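    Next, freeze these layers, since they are already trained.

        for layer in base_model.layers:
            layer.trainable = False

    The classification part has to be defined anew for the task at hand, and you can adjust it to your own situation. For example:

        from keras.layers import Dense, GlobalAveragePooling2D

        x = base_model.output
        # Add our own fully connected classification layers
        x = GlobalAveragePooling2D()(x)
        x = Dense(1024, activation='relu')(x)
        predictions = Dense(nb_classes, activation='softmax')(x)

    or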
        from keras.layers import Dense, Flatten

        x = base_model.output
        # Add our own fully connected classification layer
        x = Flatten()(x)
        predictions = Dense(nb_classes, activation='softmax')(x)

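
A rough way to compare the two heads is to count their trainable parameters. The sketch below assumes a hypothetical feature map of shape (5, 5, 2048) coming out of the frozen base. The real shape depends on the base model and the input size, and `dense_params` is a helper invented for this example.

```python
def dense_params(n_inputs, n_units):
    """Weights plus biases of a fully connected layer."""
    return n_inputs * n_units + n_units


# Hypothetical feature-map shape (height, width, channels); an assumption,
# not a value taken from the models in the post.
h, w, c = 5, 5, 2048
nb_classes = 2

# Head 1: GlobalAveragePooling2D -> Dense(1024) -> Dense(nb_classes)
gap_head = dense_params(c, 1024) + dense_params(1024, nb_classes)

# Head 2: Flatten -> Dense(nb_classes)
flatten_head = dense_params(h * w * c, nb_classes)

print("GAP head:    ", gap_head)
print("Flatten head:", flatten_head)
```

Under these assumed shapes the pooled head with a 1024-unit layer has roughly twenty times as many parameters as the direct Flatten head. A larger feature map would shift that balance.
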
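    3. Training the model

    Here we use fit_generator, which avoids loading all the data into memory at once; the generator also runs in parallel with the model, which improves efficiency. For example, real-time data augmentation can run on the CPU while the model trains on the GPU. (In newer Keras releases, fit accepts generators directly and fit_generator is deprecated.)

        history_ft = model.fit_generator(
            train_generator,
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            validation_data=validation_generator,
            validation_steps=validation_steps)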
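
`steps_per_epoch` and `validation_steps` are usually chosen so that one epoch covers the data once. The sketch below shows the arithmetic with assumed sample counts (25,000 training images and 2,000 validation images) and a hypothetical helper `steps_for`. Substitute the sizes of your own directories; Keras directory iterators report their image count through the `samples` attribute.

```python
import math


def steps_for(n_samples, batch_size):
    """Number of batches needed to see every sample once."""
    return math.ceil(n_samples / batch_size)


# Assumed sample counts; replace with the sizes of your own directories.
n_train, n_val, batch_size = 25000, 2000, 32
steps_per_epoch = steps_for(n_train, batch_size)
validation_steps = steps_for(n_val, batch_size)
print(steps_per_epoch, validation_steps)
```

The complete script in this post hardcodes 600 training steps and 60 validation steps instead, which with a batch size of 32 means 19,200 training images per epoch.
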

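    II. The Dogs vs. Cats dataset

    The training data is about 540 MB and the test data about 270 MB. Both can be downloaded from the official site:

    https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data

    After downloading, sort the images into two directories, dog and cat.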
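
The sorting step could be scripted along these lines. This is a sketch under the assumption that the downloaded training files are named `cat.<id>.jpg` and `dog.<id>.jpg`. The function name and the directory paths are placeholders, not part of the original post.

```python
import shutil
from pathlib import Path


def sort_into_class_dirs(src_dir, dst_dir, classes=("cat", "dog")):
    """Move files named '<class>.<id>.jpg' into dst_dir/<class>/."""
    src, dst = Path(src_dir), Path(dst_dir)
    moved = {name: 0 for name in classes}
    for path in sorted(src.glob("*.jpg")):
        label = path.name.split(".")[0]
        if label not in moved:
            continue  # skip files that do not follow the naming scheme
        target = dst / label
        target.mkdir(parents=True, exist_ok=True)
        shutil.move(str(path), str(target / path.name))
        moved[label] += 1
    return moved
```

A call such as `sort_into_class_dirs("train", "data/cat_dog_Dataset/train")` would then produce the one-subdirectory-per-class layout that `flow_from_directory` expects.
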

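    III. Training

    During training the pretrained weights are downloaded automatically, for example vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5. If you have already downloaded them, you can modify the library source so that it reads your local copy directly, for example in resnet50.py.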
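
An alternative that avoids editing library code, offered as an assumption rather than the author's method: Keras normally caches downloaded weights under `~/.keras/models`, so a file placed there under the expected name should be picked up without a new download. The helper `keras_cache_path` below is hypothetical and only builds and checks that path; the cache location may differ between Keras versions and setups. Newer Keras releases also document that the `weights` argument can be a path to a weights file.

```python
import os
from pathlib import Path


def keras_cache_path(filename, keras_dir=None):
    """Path where Keras is assumed to cache a downloaded weight file."""
    # Assumed default cache location: ~/.keras/models
    base = Path(keras_dir) if keras_dir else Path(os.path.expanduser("~")) / ".keras"
    return base / "models" / filename


weights = keras_cache_path("vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5")
status = "found" if weights.exists() else "missing, would be downloaded"
print(f"{weights} ({status})")
```
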

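    1. VGG19

    VGG19 has a depth of 26 layers and its weights take up 549 MB. The original model ends with three fully connected layers as its classifier, so I still added a 1024-unit fully connected layer here. After 10 epochs of training it reached 89% accuracy.

    2. ResNet50

    ResNet50 has a depth of 168 layers, yet its weights take up only 99 MB. I kept the classifier simple, a single layer that classifies directly. After 10 epochs of training it reached 96% accuracy.

    3. InceptionV3

    InceptionV3 has a depth of 159 layers and 92 MB of weights. Results after 10 epochs of training:

    This is the result with a single layer classifying directly.

    This is the result with an added 512-unit fully connected layer. Feel free to adjust the head and experiment.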
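
The size figures quoted for the three models (549, 99, and 92) line up with weight sizes in megabytes rather than with parameter counts. A back-of-the-envelope check, assuming float32 weights and the parameter counts listed in the Keras applications table (quoted here as assumptions; check the table for your Keras version):

```python
# Parameter counts as listed in the Keras applications table
# (assumed values; verify against the documentation for your version).
PARAMS = {
    "VGG19": 143_667_240,
    "ResNet50": 25_636_712,
    "InceptionV3": 23_851_784,
}
BYTES_PER_FLOAT32 = 4


def float32_size_mib(n_params):
    """Approximate size of float32 weights in MiB."""
    return n_params * BYTES_PER_FLOAT32 / 2**20


sizes = {name: float32_size_mib(n) for name, n in PARAMS.items()}
for name, size in sizes.items():
    print(f"{name}: {size:.1f} MiB")
```

The estimates come out close to, though not exactly at, the quoted sizes, which is consistent with those figures describing the stored weights of the full models.
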

    IV. Complete code

        # -*- coding: utf-8 -*-
        import os
        from keras.utils import plot_model
        from keras.applications.resnet50 import ResNet50
        from keras.applications.vgg19 import VGG19
        from keras.applications.inception_v3 import InceptionV3
        from keras.layers import Dense, Flatten, GlobalAveragePooling2D
        from keras.models import Model, load_model
        from keras.optimizers import SGD
        from keras.preprocessing.image import ImageDataGenerator
        import matplotlib.pyplot as plt


        class PowerTransferMode:
            # Data preparation
            def DataGen(self, dir_path, img_row, img_col, batch_size, is_train):
                if is_train:
                    datagen = ImageDataGenerator(rescale=1./255,
                        zoom_range=0.25, rotation_range=15.,
                        channel_shift_range=25., width_shift_range=0.02, height_shift_range=0.02,
                        horizontal_flip=True, fill_mode='constant')
                else:
                    datagen = ImageDataGenerator(rescale=1./255)

                generator = datagen.flow_from_directory(
                    dir_path, target_size=(img_row, img_col),
                    batch_size=batch_size,
                    #class_mode='binary',
                    shuffle=is_train)

                return generator

            # ResNet model
            def ResNet50_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True, is_plot_model=False):
                color = 3 if RGB else 1
                base_model = ResNet50(weights='imagenet', include_top=False, pooling=None, input_shape=(img_rows, img_cols, color),
                                      classes=nb_classes)

                # Freeze every layer of base_model so the bottleneck features are obtained correctly
                for layer in base_model.layers:
                    layer.trainable = False

                x = base_model.output
                # Add our own fully connected classification layer
                x = Flatten()(x)
                #x = GlobalAveragePooling2D()(x)
                #x = Dense(1024, activation='relu')(x)
                predictions = Dense(nb_classes, activation='softmax')(x)

                # Build and compile the model for training
                model = Model(inputs=base_model.input, outputs=predictions)
                sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
                model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

                # Plot the model
                if is_plot_model:
                    plot_model(model, to_file='resnet50_model.png', show_shapes=True)

                return model

            # VGG model
            def VGG19_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True, is_plot_model=False):
                color = 3 if RGB else 1
                base_model = VGG19(weights='imagenet', include_top=False, pooling=None, input_shape=(img_rows, img_cols, color),
                                   classes=nb_classes)

                # Freeze every layer of base_model so the bottleneck features are obtained correctly
                for layer in base_model.layers:
                    layer.trainable = False

                x = base_model.output
                # Add our own fully connected classification layers
                x = GlobalAveragePooling2D()(x)
                x = Dense(1024, activation='relu')(x)
                predictions = Dense(nb_classes, activation='softmax')(x)

                # Build and compile the model for training
                model = Model(inputs=base_model.input, outputs=predictions)
                sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
                model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

                # Plot the model
                if is_plot_model:
                    plot_model(model, to_file='vgg19_model.png', show_shapes=True)

                return model

            # InceptionV3 model
            def InceptionV3_model(self, lr=0.005, decay=1e-6, momentum=0.9, nb_classes=2, img_rows=197, img_cols=197, RGB=True,
                                  is_plot_model=False):
                color = 3 if RGB else 1
                base_model = InceptionV3(weights='imagenet', include_top=False, pooling=None,
                                         input_shape=(img_rows, img_cols, color),
                                         classes=nb_classes)

                # Freeze every layer of base_model so the bottleneck features are obtained correctly
                for layer in base_model.layers:
                    layer.trainable = False

                x = base_model.output
                # Add our own fully connected classification layers
                x = GlobalAveragePooling2D()(x)
                x = Dense(1024, activation='relu')(x)
                predictions = Dense(nb_classes, activation='softmax')(x)

                # Build and compile the model for training
                model = Model(inputs=base_model.input, outputs=predictions)
                sgd = SGD(lr=lr, decay=decay, momentum=momentum, nesterov=True)
                model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

                # Plot the model
                if is_plot_model:
                    plot_model(model, to_file='inception_v3_model.png', show_shapes=True)

                return model

            # Train the model
            def train_model(self, model, epochs, train_generator, steps_per_epoch, validation_generator, validation_steps, model_url, is_load_model=False):
                # Load a previously saved model, if requested and present
                if is_load_model and os.path.exists(model_url):
                    model = load_model(model_url)

                history_ft = model.fit_generator(
                    train_generator,
                    steps_per_epoch=steps_per_epoch,
                    epochs=epochs,
                    validation_data=validation_generator,
                    validation_steps=validation_steps)
                # Save the model
                model.save(model_url, overwrite=True)
                return history_ft

            # Plot the training curves
            def plot_training(self, history):
                hist = history.history
                # Older Keras logs 'acc'/'val_acc'; newer releases log 'accuracy'/'val_accuracy'
                acc = hist['acc'] if 'acc' in hist else hist['accuracy']
                val_acc = hist['val_acc'] if 'val_acc' in hist else hist['val_accuracy']
                loss = hist['loss']
                val_loss = hist['val_loss']
                epochs = range(len(acc))
                plt.plot(epochs, acc, 'b-')
                plt.plot(epochs, val_acc, 'r')
                plt.title('Training and validation accuracy')
                plt.figure()
                plt.plot(epochs, loss, 'b-')
                plt.plot(epochs, val_loss, 'r-')
                plt.title('Training and validation loss')
                plt.show()


        if __name__ == '__main__':
            image_size = 197
            batch_size = 32

            transfer = PowerTransferMode()

            # Get the data
            train_generator = transfer.DataGen('data/cat_dog_Dataset/train', image_size, image_size, batch_size, True)
            validation_generator = transfer.DataGen('data/cat_dog_Dataset/test', image_size, image_size, batch_size, False)

            # VGG19
            #model = transfer.VGG19_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=False)
            #history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'vgg19_model_weights.h5', is_load_model=False)

            # ResNet50
            model = transfer.ResNet50_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=False)
            history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'resnet50_model_weights.h5', is_load_model=False)

            # InceptionV3
            #model = transfer.InceptionV3_model(nb_classes=2, img_rows=image_size, img_cols=image_size, is_plot_model=True)
            #history_ft = transfer.train_model(model, 10, train_generator, 600, validation_generator, 60, 'inception_v3_model_weights.h5', is_load_model=False)

            # Accuracy and loss curves from training
            transfer.plot_training(history_ft)
  • Original article: https://www.cnblogs.com/shuimuqingyang/p/11231748.html