zoukankan      html  css  js  c++  java
  • Keras用动态数据生成器(DataGenerator)和fitgenerator动态训练模型

     

     最近做Kaggle的图像分类比赛:RSNA Intracranial Hemorrhage Detection (https://www.kaggle.com/c/rsna-intracranial-hemorrhage-detection/overview)以及阅读Yolov3

    源码的时候接触到深度学习训练时一个有趣的技巧,那就是构造生成器generator 并且用Keras 的fit_generator来批量生成数据,释放内存,该方法适合于大规模数据集的训练。一个DataGenerator是keras的Sequence类的继承类,一般要包含__len__,__getitem__, on_epoch_end等方法,例如下面的批量图片数据生成器:

    class DataGenerator(keras.utils.Sequence):
          
          
          def __init__(self, list_IDs, labels, batch_size=1, img_size=(512, 512), 
                       img_dir, *args, **kwargs):
    
             """
                self.list_IDs:存放所有需要训练的图片文件名的列表。
                self.labels:记录图片标注的分类信息的pandas.DataFrame数据类型,已经预先给定。
                self.batch_size:每次批量生成,训练的样本大小。
                self.img_size:训练的图片尺寸。
                self.img_dir:图片在电脑中存放的路径。
          
          
             """
    
              
              self.list_IDs = list_IDs
              self.labels = labels
              self.batch_size = batch_size
              self.img_size = img_size
              self.img_dir = img_dir
              self.on_epoch_end()
    
          def __len__(self):
              
              """
                 返回生成器的长度,也就是总共分批生成数据的次数。
                 
              """
              return int(ceil(len(self.list_IDs) / self.batch_size))
    
         def __getitem__(self, index):
             
             """
                该函数返回每次我们需要的经过处理的数据。
             """
             
             indices = self.indices[index*self.batch_size:(index+1)*self.batch_size]
             list_IDs_temp = [self.list_IDs[k] for k in indices]
             X, Y = self.__data_generation(list_IDs_temp)
             return X, Y
    
         def on_epoch_end(self):
             
             """
                该函数将在训练时每一个epoch结束的时候自动执行,在这里是随机打乱索引次序以方便下一batch运行。
    
             """
             self.indices = np.arange(len(self.list_IDs))
             np.random.shuffle(self.indices)
    
         def __data_generation(self, list_IDs_temp):
    
            """
               给定文件名,生成数据。
            """
            X = np.empty((self.batch_size, *self.img_size, 1))
            Y = np.empty((self.batch_size, 6), dtype=np.float32)
    
           for i, ID in enumerate(list_IDs_temp):
           X[i,] = mpimg.imread(self.img_dir+ID+".png")
           Y[i,] = self.labels.loc[ID].values
    
           return X, Y

    有了这个生成器,我们就可以用fit_generator 方法进行训练,格式套路如下:

    model.fit_generator(generator,

    steps_per_epoch=...,

    epochs=...,

    verbose=...,

    callbacks=...,

    validation_data=...,

    validation_steps=...,

    validation_freq=...,

    class_weight=None=...,

    max_queue_size=...

    workers=...,

    use_multiprocessing=...,

    )

    除此以外我们还可以搞批量预测:

    model.predict_generator()

  • 相关阅读:
    安卓远程工具介绍及下载地址
    kylinos-kysec介绍
    远程控制工具ToDesk介绍
    kylinos桌面和服务器系统重置密码
    APT仓库目录和repository目录结构
    使用LVM实现动态磁盘管理
    如何实现访问http自动跳转https
    TypeScript学习 ———— 四、泛型
    TypeScript学习 ———— 三、function
    TypeScript学习 ———— 二、接口
  • 原文地址:https://www.cnblogs.com/szqfreiburger/p/11621261.html
Copyright © 2011-2022 走看看