  • Splitting a dataset with sklearn

    https://blog.csdn.net/wade1203/article/details/91453804

    from sklearn.model_selection import train_test_split

    Function: train_test_split(data, label, test_size=0.3, random_state=2020), which returns data_train, data_test, label_train, label_test in that order.
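    A minimal sketch of the call and its four return values, using made-up image and label file names (the names are hypothetical, chosen only for illustration). Passing the label list together with the data keeps every sample paired with its label after the shuffle.

```python
from sklearn.model_selection import train_test_split

# Hypothetical file names: 10 images and their matching label files
data = ['img_{}.jpg'.format(i) for i in range(10)]
label = ['img_{}.txt'.format(i) for i in range(10)]

# Order of the results: data_train, data_test, label_train, label_test
data_train, data_test, label_train, label_test = train_test_split(
    data, label, test_size=0.3, random_state=2020)

print(len(data_train), len(data_test))  # 7 3

# Each image is still paired with the label file that has the same stem
for img, txt in zip(data_train + data_test, label_train + label_test):
    assert img[:-4] == txt[:-4]
```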

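    Parameters:

      data: the samples to split (a list here; NumPy arrays, SciPy sparse matrices and pandas DataFrames are accepted as well)

      label: the labels of those samples, in the same order and with the same length as data

      test_size: a float between 0.0 and 1.0 is the fraction of the samples used for testing; an integer is the absolute number of test samples. It is set to 0.3 here, so the train:test ratio is 0.7:0.3

      random_state: the seed for the shuffle, so that every run produces the same split. Any fixed integer, 0 included, gives a reproducible split; only when it is left unset (None, the default) does the split change from run to run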
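    A short sketch of the two parameters above on a toy list of 20 integers (the data is made up for illustration): a float test_size is a fraction, an integer test_size is a count, and a fixed seed, including 0, reproduces the same split. Leaving random_state unset would normally give a different split on every call.

```python
from sklearn.model_selection import train_test_split

samples = list(range(20))

# Float test_size: a fraction of the samples (ceil(0.3 * 20) = 6 test samples)
train_f, test_f = train_test_split(samples, test_size=0.3, random_state=0)
# Integer test_size: an absolute number of test samples
train_i, test_i = train_test_split(samples, test_size=5, random_state=0)
print(len(test_f), len(test_i))  # 6 5

# A fixed seed (0 included) reproduces exactly the same split on every call
a = train_test_split(samples, test_size=0.3, random_state=0)
b = train_test_split(samples, test_size=0.3, random_state=0)
print(a == b)  # True
```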

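    Code example. Note that the label files are not passed to train_test_split here: only the image file names are split, and each image's .txt label is moved along with it (adapt this as needed):

    Split the image_10000 and txt_10000 folders 7:3 into train/img, train/gt and test/img, test/gt. Files are moved rather than copied, so the original folders no longer exist afterwards.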

    from sklearn.model_selection import train_test_split
    import os
    import shutil

    if __name__ == '__main__':
        root = './ICPR_text_train_part2_20180313'
        pth_img = os.path.join(root, 'image_10000')  # original images
        pth_txt = os.path.join(root, 'txt_10000')    # original label (gt) files
        test_dir = os.path.join(root, 'test')
        train_dir = os.path.join(root, 'train')
        img_train_pth = os.path.join(train_dir, 'img')
        gt_train_pth = os.path.join(train_dir, 'gt')

        # Create test/, train/img and train/gt (no error if they already exist)
        os.makedirs(test_dir, exist_ok=True)
        os.makedirs(img_train_pth, exist_ok=True)
        os.makedirs(gt_train_pth, exist_ok=True)

        # Split only the image names; sort first so the split does not depend
        # on the arbitrary order returned by os.listdir()
        files = sorted(f for f in os.listdir(pth_img) if f.endswith('.jpg'))
        train, test = train_test_split(files, test_size=0.3, random_state=2020)
        print('train: {} images, test: {} images'.format(len(train), len(test)))

        # Move each training image and its .txt label into train/img and train/gt
        for name in train:
            name_txt = os.path.splitext(name)[0] + '.txt'
            shutil.move(os.path.join(pth_img, name), os.path.join(img_train_pth, name))
            shutil.move(os.path.join(pth_txt, name_txt), os.path.join(gt_train_pth, name_txt))
        print('moved {} images in total'.format(len(train)))

        # What remains in image_10000 / txt_10000 is the test set:
        # rename the folders to img / gt and move them into test/
        shutil.move(pth_img, os.path.join(test_dir, 'img'))
        shutil.move(pth_txt, os.path.join(test_dir, 'gt'))
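    The script above moves files, so image_10000 and txt_10000 are gone once it finishes. If you would rather keep the originals, here is a sketch of a copy-based variant. The helper name split_copy and the demo folder of empty fake files are hypothetical, and it assumes every .jpg has a .txt label with the same stem.

```python
import os
import shutil
import tempfile
from sklearn.model_selection import train_test_split


def split_copy(img_dir, gt_dir, out_dir, test_size=0.3, seed=2020):
    """Hypothetical helper: copy images and their .txt labels into
    out_dir/{train,test}/{img,gt}, leaving img_dir and gt_dir untouched."""
    # Sort so the split does not depend on os.listdir() ordering
    files = sorted(f for f in os.listdir(img_dir) if f.endswith('.jpg'))
    train, test = train_test_split(files, test_size=test_size, random_state=seed)
    for part, names in (('train', train), ('test', test)):
        img_out = os.path.join(out_dir, part, 'img')
        gt_out = os.path.join(out_dir, part, 'gt')
        os.makedirs(img_out, exist_ok=True)
        os.makedirs(gt_out, exist_ok=True)
        for name in names:
            txt = os.path.splitext(name)[0] + '.txt'
            shutil.copy2(os.path.join(img_dir, name), os.path.join(img_out, name))
            shutil.copy2(os.path.join(gt_dir, txt), os.path.join(gt_out, txt))
    return train, test


# Demo on a throwaway temp folder with 10 empty fake image/label pairs
root = tempfile.mkdtemp()
img_dir = os.path.join(root, 'image_10000')
gt_dir = os.path.join(root, 'txt_10000')
os.makedirs(img_dir)
os.makedirs(gt_dir)
for i in range(10):
    open(os.path.join(img_dir, '{}.jpg'.format(i)), 'wb').close()
    open(os.path.join(gt_dir, '{}.txt'.format(i)), 'w').close()

train, test = split_copy(img_dir, gt_dir, root)
print(len(train), len(test))     # 7 3
print(len(os.listdir(img_dir)))  # 10: the originals are still there
```

    Copying doubles the disk usage for a large dataset, which is presumably why the original script moves files instead.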
  • Original post: https://www.cnblogs.com/zzc-Andy/p/15090011.html