zoukankan      html  css  js  c++  java
  • sklearn中,数据集划分函数 StratifiedShuffleSplit.split() 使用踩坑

    在SKLearn中,StratifiedShuffleSplit 类实现了对数据集进行洗牌、分割的功能。但在今晚的实际使用中,发现该类及其方法split()仅能够对二分类样本有效。

    一个简单的例子如下:

     1 import numpy as np
     2 from sklearn.model_selection import StratifiedShuffleSplit
     3 
     4 l4 = np.array([[1,2],[3,4],[1,4],[3,5]])
     5 l5 = np.array([0,1,0,2])
     6 splt = StratifiedShuffleSplit(n_splits=1,test_size=0.5,random_state=1)
     7 for train_idx, valid_idx in splt.split(l4, l5):
     8     print(train_idx,valid_idx)
     9 print('=======')
    10 print(l4[train_idx],l4[valid_idx])
    11 print('=======')
    12 print(l5[train_idx],l5[valid_idx])

    l4 为样本输入列表,l5 为样本输出列表,其中,样本输出(l5)共有3类:[0,1,2] 此时,运行程序会报错:

    ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.

     报错信息的字面意思是:我样本输出仅有1类,需要最少2类。但问题是我实际上有3类输出样本。这个问题百度了半天也没找到合适的解答。

    后面将3类样本改为2类,该函数就能正常运行了。

     1 import numpy as np
     2 from sklearn.model_selection import StratifiedShuffleSplit
     3 
     4 l4 = np.array([[1,2],[3,4],[1,4],[3,5]])
     5 l5 = np.array([0,1,0,1])
     6 splt = StratifiedShuffleSplit(n_splits=1,test_size=0.5,random_state=1)
     7 for train_idx, valid_idx in splt.split(l4, l5):
     8     print(train_idx,valid_idx)
     9 print('=======')
    10 print(l4[train_idx],l4[valid_idx])
    11 print('=======')
    12 print(l5[train_idx],l5[valid_idx])

    注意,在上方代码第5行,将 l5 的值进行修改,样本输出仅有[0,1]两类。

    此时运行程序,运行无误。

     StratifiedShuffleSplit.split() 函数对于多分类问题还是无法正确适配。

  • 相关阅读:
    Saltstack module gem 详解
    Saltstack module freezer 详解
    Saltstack module firewalld 详解
    Saltstack module file 详解
    Saltstack module event 详解
    Saltstack module etcd 详解
    Saltstack module environ 详解
    Saltstack module drbd 详解
    Saltstack module dnsutil 详解
    获取主页_剥离百度
  • 原文地址:https://www.cnblogs.com/NosenLiu/p/14820156.html
Copyright © 2011-2022 走看看