zoukankan      html  css  js  c++  java
  • Keras神经网络data generators解决数据内存

        在使用kears训练model的时候,一般会将所有的训练数据加载到内存中,然后喂给网络,但当内存有限,且数据量过大时,此方法则不再可用。此博客,将介绍如何在多核(多线程)上实时的生成数据,并立即的送入到模型当中训练。 本篇文章由圆柱模板博主发布。

       先看一下还未改进的版本:

       

    import numpy as np
    from keras.models import Sequential
    #载入全部的数据!!
    X, y = np.load('some_training_set_with_labels.npy')
    #设计模型
    model = Sequential()
    [...] #网络结构
    model.compile()
    # 在数据集上进行模型训练
    model.fit(x=X, y=y)
    

      下面的结构将改变一次性载入全部数据的情况。接下来将介绍如何一步一步的构造数据生成器,此数据生成器也可应用在你自己的项目当中;复制下来,并根据自己的需求填充空白处。

        在构建之前先定义统一几个变量,并介绍几个小tips,对我们处理大的数据量很重要。 
    ID type为string,代表数据集中的某个样本。 
    调整以下结构,编译处理样本和他们的label:

        1.新建一个词典名叫 partition :

          

    partition[‘train’] 为训练集的ID,type为list
    partition[‘validation’] 为验证集的ID,type为list
    

      2.新建一个词典名叫 * labels * ,根据ID可找到数据集中的样本,同样可通过labels[ID]找到样本标签。 
    举个例子: 
    假设训练集包含三个样本,ID分别为id-1,id-2和id-3,相应的label分别为0,1,2。验证集包含样本ID id-4,标签为 1。此时两个词典partition和 labels分别如下:

        

    partition
    {'train': ['id-1', 'id-2', 'id-3'], 'validation': ['id-4']}
    

      

    labels
    {'id-1': 0, 'id-2': 1, 'id-3': 2, 'id-4': 1}
    

      data/ 中为数据集文件。

       

    数据生成器(data generator)

    接下来将介绍如何构建数据生成器 DataGenerator ,DataGenerator将实时的对训练模型feed数据。 
    接下来,将先初始化类。我们使此类继承自keras.utils.Sequence,这样我们可以使用多线程。

      

    def __init__(self, list_IDs, labels, batch_size=32, 
                 dim=(32,32,32), n_channels=1,
                 n_classes=10, shuffle=True):
        'Initialization'
        self.dim = dim
        self.batch_size = batch_size
        self.labels = labels
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()
    

      我们给了一些与数据相关的参数 dim,channels,classes,batch size ;方法 on_epoch_end 在一个epoch开始时或者结束时触发,shuffle决定是否在数据生成时要对数据进行打乱。

          

    def on_epoch_end(self):
      'Updates indexes after each epoch'
      self.indexes = np.arange(len(self.list_IDs))
      if self.shuffle == True:
          np.random.shuffle(self.indexes)
    

      另一个数据生成核心的方法__data_generation 是生成批数据。

       

    def __data_generation(self, list_IDs_temp):
      'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
      # Initialization
      X = np.empty((self.batch_size, *self.dim, self.n_channels))
      y = np.empty((self.batch_size), dtype=int)
    
      # Generate data
      for i, ID in enumerate(list_IDs_temp):
          # Store sample
          X[i,] = np.load('data/' + ID + '.npy')
    
          # Store class
          y[i] = self.labels[ID]
    
      return X, keras.utils.to_categorical(y, num_classes=self.n_classes)
    

      在数据生成期间,代码读取包含各个样本ID的代码ID.py.因为我们的代码是可以应用多线程的,所以可以采用更为复杂的操作,不用担心数据生成成为总体效率的瓶颈。 
    另外,我们使用Keras的方法keras.utils.to_categorical对label进行2值化 
    (比如,对6分类而言,第三个label则相应的变成 to [0 0 1 0 0 0]) 。 

        

    def __len__(self):
      'Denotes the number of batches per epoch'
      return int(np.floor(len(self.list_IDs) / self.batch_size))
    

      现在,当相应的index的batch被选到,则生成器执行_getitem_方法来生成它。

        

    def __getitem__(self, index):
      'Generate one batch of data'
      # Generate indexes of the batch
      indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
    
      # Find list of IDs
      list_IDs_temp = [self.list_IDs[k] for k in indexes]
    
      # Generate data
      X, y = self.__data_generation(list_IDs_temp)
    
      return X, y
    

      

  • 相关阅读:
    Livepool
    Eclipse最新版注释模板设置详解
    hashcode详解
    开发集成工具MyEclipse中Outline的问题
    第三章 数据链路层(二)
    Java常考面试题(四)
    collections集合的总括。
    第三章 数据链路层(一)
    Java常考面试题(三)
    Java常考面试题(二)
  • 原文地址:https://www.cnblogs.com/68xi/p/8661077.html
Copyright © 2011-2022 走看看