zoukankan      html  css  js  c++  java
  • python 多分类任务中按照类别分层采样

    在机器学习的多分类任务中,有时需要按类别进行分层采样。例如对于类别不均衡的数据,随机采样会使训练集、验证集和测试集中各类别的数据比例不一致,这会在一定程度上影响分类器的性能。此时就需要进行分层采样,保证训练集、验证集和测试集中每个类别的数据比例基本持平。

    下面是 Python 代码。

    # 将数据按照类别进行分层划分
    def save_file_stratified(filename, ssdfile_dir, categories,
                             output_dir='../data/usefuldata-711depart',
                             train_percent=70, val_percent=15):
        """Split labelled documents into train/val/test files, stratified by class.

        Within each class, the first ``train_percent`` % of its records go to
        train.txt. The next ``val_percent`` % go to val.txt and the rest go to
        test.txt. Records are taken in the order they appear in ``filename``,
        so every split keeps roughly the same class proportions. The split
        follows file order and is not randomised.
        NOTE(review): shuffle ``filename`` beforehand if a random split is
        wanted.

        Args:
            filename: index file with one comma-separated record per line,
                ``name,label[,...]``. ``name`` is matched against the document
                file names in ``ssdfile_dir`` with ``.txt`` removed. Records
                without a matching document are ignored.
            ssdfile_dir: directory of UTF-8 document files. Newlines inside a
                document are removed so each document fits on one output line.
            categories: class labels to keep. Records with an empty (or
                missing) label are counted and reported. Records whose label
                is not listed here are skipped.
            output_dir: directory that receives train.txt, val.txt and
                test.txt. Each line is formatted as ``label<TAB>content``.
            train_percent: per-class percentage assigned to the training split.
            val_percent: per-class percentage assigned to the validation split.
                The test split receives whatever remains.

        Prints the number of records with an empty label. Then, for each
        split, prints the per-class counts and their percentage distribution.
        """
        import os

        # Load every document and collapse multi-line case descriptions into
        # a single line. Key each one by its file name without ".txt".
        documents = {}
        for ssdfile in os.listdir(ssdfile_dir):
            path = os.path.join(ssdfile_dir, ssdfile)
            with open(path, 'r', encoding='utf-8') as f:
                documents[str(ssdfile).replace('.txt', '')] = f.read().replace('\n', '')

        # Map each label to its class index. The first occurrence wins if a
        # label is listed twice.
        class_index = {}
        for i, category in enumerate(categories):
            class_index.setdefault(category, i)
        num_classes = len(categories)

        # Single pass over the index file. Keep known-label records in file
        # order and count the records whose label is empty.
        records = []
        blank_tag = 0
        with open(filename, 'r', encoding='utf-8') as index_file:
            for line in index_file:
                fields = line.rstrip('\r\n').split(',')
                name = fields[0]
                if name not in documents:
                    continue
                # Strip the label so stray whitespace cannot stop it from
                # matching a category.
                label = fields[1].strip() if len(fields) > 1 else ''
                if label == '':
                    blank_tag += 1
                elif label in class_index:
                    records.append((class_index[label], label, documents[name]))

        # Total number of records per class, used for the per-class cut-offs.
        doc_count = [0] * num_classes
        for cls, _, _ in records:
            doc_count[cls] += 1

        split_names = ('train', 'val', 'test')
        split_counts = [[0] * num_classes for _ in split_names]
        seen = [0] * num_classes
        train_path, val_path, test_path = [
            os.path.join(output_dir, split_name + '.txt') for split_name in split_names]
        with open(train_path, 'w', encoding='utf-8') as f_train, \
                open(val_path, 'w', encoding='utf-8') as f_val, \
                open(test_path, 'w', encoding='utf-8') as f_test:
            outputs = (f_train, f_val, f_test)
            for cls, label, content in records:
                # The record's position within its class is counted from 0.
                # It is scaled by 100 so the percentage cut-offs compare
                # exactly in integers. For example, 100 records at 70/15
                # split into 70/15/15.
                position = seen[cls] * 100
                seen[cls] += 1
                if position < doc_count[cls] * train_percent:
                    split = 0
                elif position < doc_count[cls] * (train_percent + val_percent):
                    split = 1
                else:
                    split = 2
                outputs[split].write(label + '\t' + content + '\n')
                split_counts[split][cls] += 1

        print("" + str(blank_tag) + "个文书的行业标记为空!")
        for split_name, counts in zip(split_names, split_counts):
            print(split_name + ":")
            print(counts)
            total = sum(counts)
            print("分布:")
            # Guard against an empty split, which would otherwise raise
            # ZeroDivisionError.
            print([(c / total) * 100 if total else 0.0 for c in counts])
    if __name__ == '__main__':
        # Thirteen placeholder class labels: "class1" through "class13".
        categories = ['class%d' % number for number in range(1, 14)]
        index_path = '../data/qwdata/shuffle-try3/classified_table_ms.txt'
        documents_dir = '../data/qwdata/ms-ygscplusssdqw'
        save_file_stratified(index_path, documents_dir, categories)
    View Code

    运行后会打印出每个类别在训练集、验证集和测试集中的数量及占比,可以据此检查类别划分情况。


    需要注意的是:这是我早期写的文章。通常我们只需在训练集和验证集上做分层采样,测试集最好保持原样,不要改动。

  • 相关阅读:
    DNS部署与安全
    DHCP部署与安全
    jenkins漏洞复现
    Apache Axis2 漏洞复现
    制作war包
    JBOOS 漏洞复现
    Tomcat漏洞复现
    编写登陆接口(2)
    学习使用新工具Pycharm
    while练习99乘法表
  • 原文地址:https://www.cnblogs.com/zhouxiaosong/p/11113959.html
Copyright © 2011-2022 走看看