zoukankan      html  css  js  c++  java
  • keras fit_generator 并行


    虽然已经走在 torch boy 的路上了, 还是把碰到的这个坑给记录一下

    • 数据量较小时,我们可直接把整个数据集 load 到内存里,用 model.fit() 来拟合模型。
    • 当数据集过大比如几十个 G 时,内存撑不下,需要用 model.fit_generator 的方式来拟合。

    model.fit_generator 一般参数的配置参考官方文档就好,其中 generator, workers, use_multiprocessing 的使用有一些坑存在。

    workers=0, use_multiprocessing=False

    此时 generator 用一个普通的 generator去提供数据即可,类似官方提供的这种

    def generate_arrays_from_file(path):
        while True:
            with open(path) as f:
                for line in f:
                    # create numpy arrays of input data
                    # and labels, from each line in the file
                    x1, x2, y = process_line(line)
                    yield ({'input_1': x1, 'input_2': x2}, {'output': y})
    
    model.fit_generator(generate_arrays_from_file('/my_file.txt'),
                        steps_per_epoch=10000, epochs=10)
    

    workers>0, use_multiprocessing=True

    这时依然用一个 generator function 来做 generator在拟合的时候便会报错如下:

    PicklingError: Can't pickle <function generator_queue.<locals>.data_generator_task at
    

    且当 use_multiprocessing=True 时,如果你使用的是 generator function, 代码会把你的数据copy几份分给不同的worker去处理,但我们希望的是把一份数据平均分拆成几份给多个worker去处理。

    怎么解决上面两个问题? keras.utils.Sequence 可以做到

    很简单,继承 keras.utils.Sequence 这个类,重写自己的 len(), getitem 即可。

    class SequenceData(Sequence):
        def __init__(self, filePaths, batch_size):
            self.filePaths = filePaths[:100].copy()
            self.batch_size = batch_size
            self.Y = self.getY()
    
        def __len__(self):
            return len(self.Y) // self.batch_size
    
        def __getitem__(self, index):
            batch_X = np.zeros((self.batch_size,) + IMG_DIMS, dtype='float32')
            batch_Y_ = self.Y[index*self.batch_size: (index+1)*self.batch_size].copy()
            batch_Y_.reset_index(drop=True, inplace=True)
            assert batch_Y_.shape[0] == self.batch_size
    
            for index, rows in batch_Y_.iterrows():
                try:
                    img = _load_img(rows['path'])
                    batch_X[index, :, :, :] = img.copy()
                    batch_Y_.loc[index, 'valid'] = 1
                except:
                    batch_Y_.loc[index, 'valid'] = 0
                    traceback.print_exc()
            batch_Y = to_categorical(batch_Y_['label'], classes_num)
            return batch_X, batch_Y
    
        def __iter__(self):
            for item in (self[i] for i in range(len(self))):
                yield item
    
        def getY(self):
            Y = pd.DataFrame(self.filePaths, columns=['path'])
            Y['class'] = Y['path'].apply(lambda x: path2class(x))
            Y['label'] = Y['class'].apply(lambda x: class2label[x])
            Y = Y.sample(frac=1).reset_index(drop=True)
            return Y
    

    效果比较

    • 样本量:1000张图片
    • 模型: MobileNetV2
    • epochs: 5
    • CPU: 4核,3.4GHz
    • GPU: None

    可能数据量过小,并行的效果不是太明显。

    数据读取方式 workers use_multiprocessing 耗时/s
    内存读取 0 True 1797
    keras.utils.Sequence 0 False 1475
    keras.utils.Sequence 4 True

    参考:

  • 相关阅读:
    android pcm
    mongo DB的一般操作
    使用SQL Server 扩展事件来创建死锁的时间跟踪
    sql 日期格式汇总
    简述SQL2008部署多实例集群(学习)
    数据库压缩备份提高备份效率
    SSRS报表连接超时的问题
    classLoader.getResourceAsStream中文乱码
    jQuery与js对象互转
    sqlserver判断字段是否存在更改字段
  • 原文地址:https://www.cnblogs.com/Fosen/p/11953468.html
Copyright © 2011-2022 走看看