zoukankan      html  css  js  c++  java
  • keras创建自己训练代码

    由于 ManTraNet 的 GitHub 仓库只开源了测试(推理)代码,所以训练代码需要自己编写。

    环境:Keras + TensorFlow 1.x(代码使用了 tf.ConfigProto、tf.Session 等 TF1 接口,TF2 下需改写)

    # import src.modelCore as modelCore
    from src.modelCore import create_model
    from keras.optimizers import SGD
    from keras.preprocessing.image import ImageDataGenerator
    import tensorflow as tf
    import keras
    from keras.callbacks import ModelCheckpoint
    # TF1 session setup: allocate GPU memory on demand, cap this process at
    # 50% of GPU memory, and make Keras run on the configured session.
    config = tf.ConfigProto()
    gpu_options = config.gpu_options
    gpu_options.per_process_gpu_memory_fraction = 0.5
    gpu_options.allow_growth = True
    keras.backend.tensorflow_backend.set_session(tf.Session(config=config))
    
    
    # 加载模型
    def load_pretrain_model_by_index(pretrain_index):
        """Build a ManTraNet model using the architecture of config ``pretrain_index``.

        Config 4 uses IMC model 2 with window sizes [7, 15, 31]; every other index
        is passed straight through as the IMC model index, with window sizes
        [7, 15, 31, 63]. The feature extractor is never frozen.

        NOTE: despite the name, no pretrained weight file is loaded here, so the
        model is returned with whatever weights ``create_model`` gives it.
        """
        is_config_4 = pretrain_index == 4
        imc_idx = 2 if is_config_4 else pretrain_index
        windows = [7, 15, 31] if is_config_4 else [7, 15, 31, 63]
        freeze_featex = False  # keep the feature extractor trainable
        model = create_model(imc_idx, freeze_featex, windows)
        # Weight loading is disabled. To start from a weight file instead, re-enable:
        # weight_file = "{}/ManTraNet_Ptrain{}.h5".format(model_dir, pretrain_index )
        # assert os.path.isfile(weight_file), "ERROR: fail to locate the pretrained weight file"
        # model.load_weights( weight_file )
        return model
    
    
    def trainGenerator(batch_size, train_path, image_folder, mask_folder, aug_dict, image_color_mode="rgb",
                       mask_color_mode="grayscale", image_save_prefix="image", mask_save_prefix="mask",
                       flag_multi_class=False, num_class=2, save_to_dir=None, target_size=(256, 256), seed=1):
        '''
        Yield (image, mask) batches forever, applying the same random augmentation to both.

        Two ImageDataGenerator instances are built from the same ``aug_dict`` and
        driven with the same ``seed``, so the i-th image and the i-th mask get the
        same shuffle order and the same random transform. Set ``save_to_dir`` to a
        directory path to dump the augmented samples for visual inspection.

        Args:
            batch_size: number of samples per yielded batch.
            train_path: directory that contains ``image_folder`` and ``mask_folder``.
            image_folder: sub-directory of ``train_path`` holding the input images.
            mask_folder: sub-directory of ``train_path`` holding the masks.
                NOTE(review): pairing relies on both folders listing their files in
                the same order (same count, matching names) -- confirm for your data.
            aug_dict: keyword arguments forwarded to ``ImageDataGenerator``.
            image_color_mode: ``color_mode`` for the image stream (default "rgb").
            mask_color_mode: ``color_mode`` for the mask stream (default "grayscale").
            image_save_prefix, mask_save_prefix: file-name prefixes, only used
                together with ``save_to_dir``.
            flag_multi_class: if False (default), masks are binarised to {0, 1} so
                they are valid ``binary_crossentropy`` targets; if True they are
                yielded unchanged.
            num_class: currently unused; kept for interface compatibility.
            save_to_dir: optional directory where augmented batches are written.
            target_size: (height, width) every image and mask is resized to.
            seed: shared random seed for shuffling and augmentation.

        Yields:
            ``(img, mask)`` tuples of batched numpy arrays.
        '''
        image_datagen = ImageDataGenerator(**aug_dict)
        mask_datagen = ImageDataGenerator(**aug_dict)
        image_generator = image_datagen.flow_from_directory(
            train_path,
            classes=[image_folder],    # read only this sub-folder
            class_mode=None,           # yield images only, no class labels
            color_mode=image_color_mode,
            target_size=target_size,   # flow_from_directory always resizes; honour the caller's size
            batch_size=batch_size,
            save_to_dir=save_to_dir,
            save_prefix=image_save_prefix,  # only used when save_to_dir is given
            seed=seed)
        mask_generator = mask_datagen.flow_from_directory(
            train_path,
            classes=[mask_folder],
            class_mode=None,
            color_mode=mask_color_mode,
            target_size=target_size,   # must match the image size so pixels line up
            batch_size=batch_size,
            save_to_dir=save_to_dir,
            save_prefix=mask_save_prefix,
            seed=seed)
        for img, mask in zip(image_generator, mask_generator):
            # NOTE(review): images keep whatever scaling aug_dict applies (raw pixel
            # values unless it sets ``rescale``); confirm this matches the input
            # range the model expects.
            if not flag_multi_class:
                # Without ``rescale`` masks arrive as raw pixel values (0..255 for
                # 8-bit files), and rotation/zoom interpolation creates intermediate
                # values. binary_crossentropy needs {0, 1} targets, so scale 0..255
                # masks down and then threshold at 0.5.
                if mask.max() > 1:
                    mask = mask / 255.0
                mask = (mask > 0.5).astype(mask.dtype)
            yield (img, mask)
    
    
    # Build the config-4 ManTraNet architecture and compile it for per-pixel binary masks.
    manTraNet = load_pretrain_model_by_index(4)
    # Keyword arguments on purpose: SGD's positional order differs between Keras
    # releases (2.2.x: lr, momentum, decay; 2.3.x: learning_rate, momentum, nesterov),
    # so a positional third argument would silently stop being the decay.
    sgd = SGD(lr=0.01, momentum=0.0, decay=1e-6)
    manTraNet.compile(loss='binary_crossentropy', optimizer=sgd, metrics=['accuracy'])
    # Raw strings keep the Windows backslashes literal (e.g. "\t" must not become a TAB).
    train_path = r"C:\Users\DNY-006\Desktop\s2_data\s2_data\data"
    mask_path = r"C:\Users\DNY-006\Desktop\s2_data\s2_data\data\train_mask"  # only used by the disabled geneTrainNpy line
    # img_train, mask_train = geneTrainNpy(train_path, mask_path)
    data_gen_args = dict(
        rotation_range=0.2,       # degrees; 0.2 is almost no rotation -- NOTE(review): confirm intent
        width_shift_range=0.05,   # shift as a fraction of image width
        height_shift_range=0.05,  # shift as a fraction of image height
        shear_range=0.05,         # shear intensity (counter-clockwise shear angle)
        zoom_range=0.05,          # float z -> random zoom factor in [1 - z, 1 + z]
        horizontal_flip=True,
        fill_mode='nearest',      # fill pixels exposed by the transform with the nearest value
    )
    # Images from train_path/1111, masks from train_path/train_mask, batch size 1.
    train_generator = trainGenerator(1, train_path, '1111', 'train_mask', data_gen_args, save_to_dir=None)
    # Save the model to an .hdf5 file whenever the training loss improves.
    model_checkpoint = ModelCheckpoint('ManTraNet_owndata.hdf5', monitor='loss', verbose=1, save_best_only=True)
    # manTraNet.fit(img_train, mask_train, epochs=50, batch_size=32, shuffle=True, verbose=1, validation_split=0.3)
    # To validate during training, add: validation_data=validation_generator, validation_steps=200
    manTraNet.fit_generator(train_generator, steps_per_epoch=1000, epochs=60, callbacks=[model_checkpoint])
    from src.modelCore import create_model
    from keras.optimizers import SGD
    from keras.preprocessing.image import ImageDataGenerator
    import tensorflow as tf
    import keras
    from keras.callbacks import ModelCheckpoint
    # TF1 session setup: allocate GPU memory on demand, cap this process at
    # 50% of GPU memory, and make Keras run on the configured session.
    config = tf.ConfigProto()
    gpu_options = config.gpu_options
    gpu_options.per_process_gpu_memory_fraction = 0.5
    gpu_options.allow_growth = True
    keras.backend.tensorflow_backend.set_session(tf.Session(config=config))
    
    
    # 加载模型
    def load_pretrain_model_by_index(pretrain_index):
        """Build a ManTraNet model using the architecture of config ``pretrain_index``.

        Config 4 uses IMC model 2 with window sizes [7, 15, 31]; every other index
        is passed straight through as the IMC model index, with window sizes
        [7, 15, 31, 63]. The feature extractor is never frozen.

        NOTE: despite the name, no pretrained weight file is loaded here, so the
        model is returned with whatever weights ``create_model`` gives it.
        """
        is_config_4 = pretrain_index == 4
        imc_idx = 2 if is_config_4 else pretrain_index
        windows = [7, 15, 31] if is_config_4 else [7, 15, 31, 63]
        freeze_featex = False  # keep the feature extractor trainable
        model = create_model(imc_idx, freeze_featex, windows)
        # Weight loading is disabled. To start from a weight file instead, re-enable:
        # weight_file = "{}/ManTraNet_Ptrain{}.h5".format(model_dir, pretrain_index )
        # assert os.path.isfile(weight_file), "ERROR: fail to locate the pretrained weight file"
        # model.load_weights( weight_file )
        return model
    
    
    def trainGenerator(batch_size, train_path, image_folder, mask_folder, aug_dict, image_color_mode="rgb",
                       mask_color_mode="grayscale", image_save_prefix="image", mask_save_prefix="mask",
                       flag_multi_class=False, num_class=2, save_to_dir=None, target_size=(256, 256), seed=1):
        '''
        Yield (image, mask) batches forever, applying the same random augmentation to both.

        Two ImageDataGenerator instances are built from the same ``aug_dict`` and
        driven with the same ``seed``, so the i-th image and the i-th mask get the
        same shuffle order and the same random transform. Set ``save_to_dir`` to a
        directory path to dump the augmented samples for visual inspection.

        Args:
            batch_size: number of samples per yielded batch.
            train_path: directory that contains ``image_folder`` and ``mask_folder``.
            image_folder: sub-directory of ``train_path`` holding the input images.
            mask_folder: sub-directory of ``train_path`` holding the masks.
                NOTE(review): pairing relies on both folders listing their files in
                the same order (same count, matching names) -- confirm for your data.
            aug_dict: keyword arguments forwarded to ``ImageDataGenerator``.
            image_color_mode: ``color_mode`` for the image stream (default "rgb").
            mask_color_mode: ``color_mode`` for the mask stream (default "grayscale").
            image_save_prefix, mask_save_prefix: file-name prefixes, only used
                together with ``save_to_dir``.
            flag_multi_class: if False (default), masks are binarised to {0, 1} so
                they are valid ``binary_crossentropy`` targets; if True they are
                yielded unchanged.
            num_class: currently unused; kept for interface compatibility.
            save_to_dir: optional directory where augmented batches are written.
            target_size: (height, width) every image and mask is resized to.
            seed: shared random seed for shuffling and augmentation.

        Yields:
            ``(img, mask)`` tuples of batched numpy arrays.
        '''
        image_datagen = ImageDataGenerator(**aug_dict)
        mask_datagen = ImageDataGenerator(**aug_dict)
        image_generator = image_datagen.flow_from_directory(
            train_path,
            classes=[image_folder],    # read only this sub-folder
            class_mode=None,           # yield images only, no class labels
            color_mode=image_color_mode,
            target_size=target_size,   # flow_from_directory always resizes; honour the caller's size
            batch_size=batch_size,
            save_to_dir=save_to_dir,
            save_prefix=image_save_prefix,  # only used when save_to_dir is given
            seed=seed)
        mask_generator = mask_datagen.flow_from_directory(
            train_path,
            classes=[mask_folder],
            class_mode=None,
            color_mode=mask_color_mode,
            target_size=target_size,   # must match the image size so pixels line up
            batch_size=batch_size,
            save_to_dir=save_to_dir,
            save_prefix=mask_save_prefix,
            seed=seed)
        for img, mask in zip(image_generator, mask_generator):
            # NOTE(review): images keep whatever scaling aug_dict applies (raw pixel
            # values unless it sets ``rescale``); confirm this matches the input
            # range the model expects.
            if not flag_multi_class:
                # Without ``rescale`` masks arrive as raw pixel values (0..255 for
                # 8-bit files), and rotation/zoom interpolation creates intermediate
                # values. binary_crossentropy needs {0, 1} targets, so scale 0..255
                # masks down and then threshold at 0.5.
                if mask.max() > 1:
                    mask = mask / 255.0
                mask = (mask > 0.5).astype(mask.dtype)
            yield (img, mask)
    
    
    # Build the config-4 ManTraNet architecture and compile it for per-pixel binary masks.
    manTraNet = load_pretrain_model_by_index(4)
    # Keyword arguments on purpose: SGD's positional order differs between Keras
    # releases (2.2.x: lr, momentum, decay; 2.3.x: learning_rate, momentum, nesterov),
    # so a positional third argument would silently stop being the decay.
    sgd = SGD(lr=0.01, momentum=0.0, decay=1e-6)
    manTraNet.compile(loss='binary_crossentropy', optimizer=sgd, metrics=['accuracy'])
    # Raw strings keep the Windows backslashes literal (e.g. "\t" must not become a TAB).
    train_path = r"C:\Users\DNY-006\Desktop\s2_data\s2_data\data"
    mask_path = r"C:\Users\DNY-006\Desktop\s2_data\s2_data\data\train_mask"  # only used by the disabled geneTrainNpy line
    # img_train, mask_train = geneTrainNpy(train_path, mask_path)
    data_gen_args = dict(
        rotation_range=0.2,       # degrees; 0.2 is almost no rotation -- NOTE(review): confirm intent
        width_shift_range=0.05,   # shift as a fraction of image width
        height_shift_range=0.05,  # shift as a fraction of image height
        shear_range=0.05,         # shear intensity (counter-clockwise shear angle)
        zoom_range=0.05,          # float z -> random zoom factor in [1 - z, 1 + z]
        horizontal_flip=True,
        fill_mode='nearest',      # fill pixels exposed by the transform with the nearest value
    )
    # Images from train_path/1111, masks from train_path/train_mask, batch size 1.
    train_generator = trainGenerator(1, train_path, '1111', 'train_mask', data_gen_args, save_to_dir=None)
    # Save the model to an .hdf5 file whenever the training loss improves.
    model_checkpoint = ModelCheckpoint('ManTraNet_owndata.hdf5', monitor='loss', verbose=1, save_best_only=True)
    # manTraNet.fit(img_train, mask_train, epochs=50, batch_size=32, shuffle=True, verbose=1, validation_split=0.3)
    # To validate during training, add: validation_data=validation_generator, validation_steps=200
    manTraNet.fit_generator(train_generator, steps_per_epoch=1000, epochs=60, callbacks=[model_checkpoint])

    参考

    https://blog.csdn.net/Xnion/article/details/105797671

  • 相关阅读:
    Linux修改主机名称方法
    高精度模板(含加减乘除四则运算)
    背包问题(0-1背包,完全背包,多重背包知识概念详解)
    [Swust OJ 385]--自动写诗
    [Swust OJ 403]--集合删数
    [Swust OJ 409]--小鼠迷宫问题(BFS+记忆化搜索)
    [Swust OJ 360]--加分二叉树(区间dp)
    [Swust OJ 402]--皇宫看守(树形dp)
    [Swust OJ 581]--彩色的石子(状压dp)
    [Swust OJ 589]--吃西瓜(三维矩阵压缩)
  • 原文地址:https://www.cnblogs.com/bob-jianfeng/p/13840177.html
Copyright © 2011-2022 走看看