zoukankan      html  css  js  c++  java
  • keras训练大量数据的办法

           最近在做一个鉴黄的项目,数据量比较大,有几百个G,一次性加入内存再去训练模青型是不现实的。

    查阅资料发现keras中可以用两种方法解决,一是将数据转为tfrecord,但转换后数据大小会方法不好;另外一种就是利用generator,先一次加入所有数据的路径,然后每个batch的读入

    # 读取图片函数
    def get_im_cv2(paths, img_rows, img_cols, color_type=1, normalize=True):
        '''
        参数:
            paths:要读取的图片路径列表
            img_rows:图片行
            img_cols:图片列
            color_type:图片颜色通道
        返回: 
            imgs: 图片数组
        '''
        # Load as grayscale
        imgs = []
        for path in paths:
            if color_type == 1:
                img = cv2.imread(path, 0)
            elif color_type == 3:
                img = cv2.imread(path)
            # Reduce size
            resized = cv2.resize(img, (img_cols, img_rows))
            if normalize:
                resized = resized.astype('float32')
                resized /= 127.5
                resized -= 1. 
            
            imgs.append(resized)
            
        return np.array(imgs).reshape(len(paths), img_rows, img_cols, color_type)
    def get_train_batch(X_train, y_train, batch_size, img_w, img_h, color_type, is_argumentation):
        '''
        参数:
            X_train:所有图片路径列表
            y_train: 所有图片对应的标签列表
            batch_size:批次
            img_w:图片宽
            img_h:图片高
            color_type:图片类型
            is_argumentation:是否需要数据增强
        返回: 
            一个generator,x: 获取的批次图片 y: 获取的图片对应的标签
        '''
        while 1:
            for i in range(0, len(X_train), batch_size):
                x = get_im_cv2(X_train[i:i+batch_size], img_w, img_h, color_type)
                y = y_train[i:i+batch_size]
                if is_argumentation:
                    # 数据增强
                    x, y = img_augmentation(x, y)
                # 最重要的就是这个yield,它代表返回,返回以后循环还是会继续,然后再返回。就比如有一个机器一直在作累加运算,但是会把每次累加中间结果告诉你一样,直到把所有数加完
                yield(np.array(x}, np.array(y))
    result = model.fit_generator(generator=get_train_batch(X_train, y_train, train_batch_size, img_w, img_h, color_type, True), 
              steps_per_epoch=1351, 
              epochs=50, verbose=1,
              validation_data=get_train_batch(X_valid, y_valid, valid_batch_size,img_w, img_h, color_type, False),
              validation_steps=52,
              callbacks=[ckpt, early_stop],
              max_queue_size=capacity,
              workers=1)

    参考:https://www.jianshu.com/p/5bdae9dcfc9c

              https://keras.io/zh/models/model/

     

  • 相关阅读:
    会话状态服务器解决方法
    让笔记本在插上外置鼠标时触摸板自动关闭
    “检测到有潜在危险的 Request.Form(QueryString) 值”的解决方法
    SQL Server2008不能登录解决方法
    SqlHelper
    修改IE查看源代码编辑器
    由于启动用户实例的进程时出错,导致无法生成 SQL Server 的用户实例解决办法
    性能测试用户模型(二):用户模型图
    索引帖:性能测试新手误区系列
    性能测试用户模型(三):基础数据分析、场景数据
  • 原文地址:https://www.cnblogs.com/573177885qq/p/11984470.html
Copyright © 2011-2022 走看看