zoukankan      html  css  js  c++  java
  • Keras用动态数据生成器(DataGenerator)和fitgenerator动态训练模型

     

     最近做Kaggle的图像分类比赛:RSNA Intracranial Hemorrhage Detection (https://www.kaggle.com/c/rsna-intracranial-hemorrhage-detection/overview)以及阅读Yolov3

    源码的时候接触到深度学习训练时一个有趣的技巧,那就是构造生成器generator 并且用Keras 的fit_generator来批量生成数据,释放内存,该方法适合于大规模数据集的训练。一个DataGenerator是keras的Sequence类的继承类,一般要包含__len__,__getitem__, on_epoch_end等方法,例如下面的批量图片数据生成器:

    class DataGenerator(keras.utils.Sequence):
          
          
          def __init__(self, list_IDs, labels, batch_size=1, img_size=(512, 512), 
                       img_dir, *args, **kwargs):
    
             """
                self.list_IDs:存放所有需要训练的图片文件名的列表。
                self.labels:记录图片标注的分类信息的pandas.DataFrame数据类型,已经预先给定。
                self.batch_size:每次批量生成,训练的样本大小。
                self.img_size:训练的图片尺寸。
                self.img_dir:图片在电脑中存放的路径。
          
          
             """
    
              
              self.list_IDs = list_IDs
              self.labels = labels
              self.batch_size = batch_size
              self.img_size = img_size
              self.img_dir = img_dir
              self.on_epoch_end()
    
          def __len__(self):
              
              """
                 返回生成器的长度,也就是总共分批生成数据的次数。
                 
              """
              return int(ceil(len(self.list_IDs) / self.batch_size))
    
         def __getitem__(self, index):
             
             """
                该函数返回每次我们需要的经过处理的数据。
             """
             
             indices = self.indices[index*self.batch_size:(index+1)*self.batch_size]
             list_IDs_temp = [self.list_IDs[k] for k in indices]
             X, Y = self.__data_generation(list_IDs_temp)
             return X, Y
    
         def on_epoch_end(self):
             
             """
                该函数将在训练时每一个epoch结束的时候自动执行,在这里是随机打乱索引次序以方便下一batch运行。
    
             """
             self.indices = np.arange(len(self.list_IDs))
             np.random.shuffle(self.indices)
    
         def __data_generation(self, list_IDs_temp):
    
            """
               给定文件名,生成数据。
            """
            X = np.empty((self.batch_size, *self.img_size, 1))
            Y = np.empty((self.batch_size, 6), dtype=np.float32)
    
           for i, ID in enumerate(list_IDs_temp):
           X[i,] = mpimg.imread(self.img_dir+ID+".png")
           Y[i,] = self.labels.loc[ID].values
    
           return X, Y

    有了这个生成器,我们就可以用fit_generator 方法进行训练,格式套路如下:

    model.fit_generator(generator,

    steps_per_epoch=...,

    epochs=...,

    verbose=...,

    callbacks=...,

    validation_data=...,

    validation_steps=...,

    validation_freq=...,

    class_weight=None=...,

    max_queue_size=...

    workers=...,

    use_multiprocessing=...,

    )

    除此以外我们还可以搞批量预测:

    model.predict_generator()

  • 相关阅读:
    HDU4507 吉哥系列故事――恨7不成妻(数位dp)
    UCF Local Programming Contest 2017 G题(dp)
    ICPC Latin American Regional Contests 2019 I题
    UCF Local Programming Contest 2017 H题(区间dp)
    HDU2089 不要62
    AcWing1084 数字游戏II(数位dp)
    UCF Local Programming Contest 2017 F题(最短路)
    Google Code Jam 2019 Round 1A Pylons(爆搜+贪心)
    AcWing1083 Windy数(数位dp)
    Vue
  • 原文地址:https://www.cnblogs.com/szqfreiburger/p/11621261.html
Copyright © 2011-2022 走看看