zoukankan      html  css  js  c++  java
  • sklearn中,数据集划分函数 StratifiedShuffleSplit.split() 使用踩坑

    在SKLearn中,StratifiedShuffleSplit 类实现了对数据集进行洗牌、分割的功能。但在今晚的实际使用中,发现该类及其方法split()仅能够对二分类样本有效。

    一个简单的例子如下:

     1 import numpy as np
     2 from sklearn.model_selection import StratifiedShuffleSplit
     3 
     4 l4 = np.array([[1,2],[3,4],[1,4],[3,5]])
     5 l5 = np.array([0,1,0,2])
     6 splt = StratifiedShuffleSplit(n_splits=1,test_size=0.5,random_state=1)
     7 for train_idx, valid_idx in splt.split(l4, l5):
     8     print(train_idx,valid_idx)
     9 print('=======')
    10 print(l4[train_idx],l4[valid_idx])
    11 print('=======')
    12 print(l5[train_idx],l5[valid_idx])

    l4 为样本输入列表,l5 为样本输出列表,其中,样本输出(l5)共有3类:[0,1,2] 此时,运行程序会报错:

    ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.

     报错信息的字面意思是:我样本输出仅有1类,需要最少2类。但问题是我实际上有3类输出样本。这个问题百度了半天也没找到合适的解答。

    后面将3类样本改为2类,该函数就能正常运行了。

     1 import numpy as np
     2 from sklearn.model_selection import StratifiedShuffleSplit
     3 
     4 l4 = np.array([[1,2],[3,4],[1,4],[3,5]])
     5 l5 = np.array([0,1,0,1])
     6 splt = StratifiedShuffleSplit(n_splits=1,test_size=0.5,random_state=1)
     7 for train_idx, valid_idx in splt.split(l4, l5):
     8     print(train_idx,valid_idx)
     9 print('=======')
    10 print(l4[train_idx],l4[valid_idx])
    11 print('=======')
    12 print(l5[train_idx],l5[valid_idx])

    注意,在上方代码第5行,将 l5 的值进行修改,样本输出仅有[0,1]两类。

    此时运行程序,运行无误。

     StratifiedShuffleSplit.split() 函数对于多分类问题还是无法正确适配。

  • 相关阅读:
    身份证相关类
    微信开发相关文档
    password、文件MD5加密,passwordsha256、sha384、sha512Hex等加密
    图的割点(边表集实现)
    动态库DLL中类的使用
    吴恩达机器学习笔记_第三周
    Android官方开发文档Training系列课程中文版:性能优化建议
    简单算法汇总
    Gson解析第三方提供Json数据(天气预报,新闻等)
    Java字节码 小结
  • 原文地址:https://www.cnblogs.com/NosenLiu/p/14820156.html
Copyright © 2011-2022 走看看