zoukankan      html  css  js  c++  java
  • Keras 使用自己编写的数据生成器

    使用自己编写的数据生成器,配合keras的fit_generator训练模型

    注意:模型结构要和生成器生成数据的尺寸要对应,txt存的数据路径一般是有序的,想办法打乱它

    # 以下部分代码,仅做示意
    ……
    def gen_mine():
        txtpath = './2.txt' # 数据路径存在txt
        data_train = []
        data_labels = []
        cnt = 0 # 用于批量计数
        for n in open(txtpath):
            img = cv2.imread(n[:-1]) # 最后一个字节是换行符,去掉它
            img_64 = cv2.resize(img,(64,64)) # 输入到模型前要统一尺寸
            img_rgb = img_64[:,:,::-1] # cv读的数据是bgr,这里改成标准的rgb
            if n.split('/')[1] == 'file_N': # 由于我是根据文件夹的名字定的标签,这个看自己的需求
                label = [0,1,0] # 注意要写成独热编码的形式
            else:
                label = [1,0,0]
            data_train.append(img_rgb)
            data_labels.append(label)
            cnt = cnt + 1
            if cnt == BS:
                cnt = 0 # 初始化
                data_train = np.array(data_train)
                data_labels = np.array(data_labels)
                print(data_train.shape, data_labels.shape)
                yield (data_train, data_labels)
                data_train = [] # 初始化
                data_labels = []
    ……
    model.fit_generator(gen_mine(),steps_per_epoch=steps_per_epoch_, epochs=NUM_EPOCHS, class_weight = 'auto', max_queue_size=1,workers=1)
    
  • 相关阅读:
    Python的正则表达式
    Python的异常处理
    Python的类和对象
    Python乘法口诀表
    Python的文件操作
    三层架构介绍和MVC设计模型介绍
    spring的组件使用
    IDEA使用maven搭建spring项目
    Java集合——Collection接口
    Java集合——概述
  • 原文地址:https://www.cnblogs.com/niulang/p/13522487.html
Copyright © 2011-2022 走看看