zoukankan      html  css  js  c++  java
  • 不平衡学习方法理论和实战总结

    原文:http://blog.csdn.net/hero_fantao/article/details/35784773

    不平衡学习方法

    机器学习中样本不平衡问题大致分为两方面:

    (1)类别中样本比率不平衡,但是几个类别的样本都足够多;

    (2)类别中某类样本较少。

    对第二个问题,其实不是我们重点,因为样本不足的话,覆盖空间是很小,如果特征足够多的话,这种数据对模型学习的价值也不大,所以,对这个问题,好的方法只能是找尽量多的小类样本来覆盖样本空间。

    现在主要讨论第一个问题。

    一: 采样方法

    1. 随机重采样(random oversampling):

      样本不平衡时候,对小类样本就行随机重采样,以达到平衡。这种方法只是对小类样本进行简单的拷贝,缺点是容易over-fit,比如在决策树分类的时候,很有可能一个终端叶子节点的样本都是一个样本的拷贝而已,扩展性不足,这可能会提高模型训练的精度,但是对未知测试样本的预测可能是很差的。

       

    2. 随机欠采样(random oversampling):

         样本不平衡时候,对大类样本就行随机欠采样,就是取部分大类样本,以达到平衡。欠采样的问题是对样本减少可能会缺失样本空间中重要数据,降低准确性。

    3. Synthetic Sampling with Data Generation

      对小类样本进行近似数据样本生成。对小类样本计算KNN,找出K个相近样本,根据K近邻样本于当前样本的距离,生成新的样本。

     

     这种方法突破了原有的简单的重复采样的方法,通过创建新的小样本,丰富了小样本的样本空间,弥补了小样本样本空间不足的问题。缺点是它对所有的小类样本都计算相同的KNN。试想下对于那些和大类样本有明显的区分度的小样本,对于这些产生多余的样本价值不大。

    4. Adaptive Synthetic Sampling

      Adaptive Synthetic Sampling是一种修正方法,他试图增加小样本中和大类样本比较相近的样本sampling。

    方法如下:

    二 代价学习方法

    一是从样本角度来看,尽量做到样本平衡,然后来用模型的学习。还有种就是通过设置不同样本误判的代价,比如设置小样本误判的代价大一些。个人的感觉,这种方法和一中重采样的效果差不多,牺牲一个换取另外一个。个人觉得一种好的方法是,正负样本不平衡时候,每次选取一部分大类样本和全部小样本,尽量平衡,训练一个模型。重复以上操作,训练得到若干模型,把这些模型做个voting,获得最终预测结果,可以效仿Adaboost,对每个模型进行加权。其实,voting的方法就能达到很不多的效果。

    参考文献:

    [1] He H, Garcia E A. Learning from imbalanced data[J]. Knowledge and Data Engineering, IEEE Transactions on, 2009, 21(9): 1263-1284.

    [2] https://github.com/fmfn/UnbalancedDataset(2014/12/07 @phunter_lau分享的一个模块)

    附上Adaptive Synthetic Sampling源码:

    [python] view plaincopy在CODE上查看代码片派生到我的代码片
     
      1. ''''' 
      2. Created on 2014/03/09 
      3.  
      4. @author: dylan 
      5. '''  
      6. from sklearn.neighbors import NearestNeighbors  
      7. import numpy as np  
      8. import random  
      9.   
      10.   
      11.   
      12. def get_class_count(y, minorityclasslabel = 1):  
      13.     minorityclasslabel_count = len(np.where(y == minorityclasslabel)[0])  
      14.     maxclasslabel_count = len(np.where(y == (1 - minorityclasslabel))[0])  
      15.       
      16.     return maxclasslabel_count, minorityclasslabel_count  
      17.       
      18.       
      19. # @param: X The datapoints e.g.: [f1, f2, ... ,fn]  
      20. # @param: y the classlabels e.g: [0,1,1,1,0,...,Cn]  
      21. # @param ms: The amount of samples in the minority group  
      22. # @param ml: The amount of samples in the majority group  
      23. # @return: the G value, which indicates how many samples should be generated in total, this can be tuned with beta  
      24. def getG(ml, ms, beta):  
      25.     return (ml-ms)*beta  
      26.   
      27.   
      28. # @param: X The datapoints e.g.: [f1, f2, ... ,fn]  
      29. # @param: y the classlabels e.g: [0,1,1,1,0,...,Cn]  
      30. # @param: minorityclass: The minority class  
      31. # @param: K: The amount of neighbours for Knn  
      32. # @return: rlist: List of r values  
      33. def getRis(X,y,indicesMinority,minorityclasslabel,K):      
      34.       
      35.     ymin = np.array(y)[indicesMinority]  
      36.     Xmin = np.array(X)[indicesMinority]  
      37.     neigh = NearestNeighbors(n_neighbors= K)  
      38.     neigh.fit(X)  
      39.       
      40.     rlist = [0]*len(ymin)  
      41.     normalizedrlist = [0]*len(ymin)  
      42.       
      43.     for i in xrange(len(ymin)):  
      44.         indices = neigh.kneighbors(Xmin[i],K,False)[0]  
      45. #         print'y[indices] == (1 - minorityclasslabel):'  
      46. #         print y[indices]  
      47. #         print len(np.where(y[indices] == ( 1- minorityclasslabel))[0])  
      48.         rlist[i] = len(np.where(y[indices] == ( 1- minorityclasslabel))[0])/(K + 0.0)  
      49.           
      50.     normConst = sum(rlist)  
      51.   
      52.     for j in xrange(len(rlist)):  
      53.         normalizedrlist[j] = (rlist[j]/normConst)  
      54.   
      55.     return normalizedrlist  
      56.   
      57. def get_indicesMinority(y, minorityclasslabel = 1):  
      58.     y_new = []  
      59.     for i in range(len(y)):  
      60.         if y[i] == 1:  
      61.             y_new.append(1)  
      62.         else:  
      63.             y_new.append(0)  
      64.     y_new = np.asarray(y_new)  
      65.     indicesMinority = np.where(y_new == minorityclasslabel)[0]   
      66.          
      67.     return indicesMinority, y_new  
      68.   
      69. def generateSamples(X, y, minorityclasslabel = 1, K =5,beta = 0.3):  
      70.     syntheticdata_X = []  
      71.     syntheticdata_y = []  
      72.       
      73.       
      74.     indicesMinority, y_new = get_indicesMinority(y)  
      75.     ymin = y[indicesMinority]  
      76.     Xmin = X[indicesMinority]  
      77.       
      78.     rlist = getRis(X, y_new, indicesMinority, minorityclasslabel, K)  
      79.     ml, ms = get_class_count(y_new)  
      80.     G = getG(ml,ms, beta = beta)  
      81.      
      82.     neigh = NearestNeighbors(n_neighbors=K)  
      83.     neigh.fit(Xmin)  
      84.       
      85.     for k in xrange(len(ymin)):  
      86.         g = int(np.round(rlist[k]*G))  
      87.   
      88.         neighb_indx = neigh.kneighbors(Xmin[k],K,False)[0]  
      89.               
      90.         for l in xrange(g):  
      91.             ind = random.choice(neighb_indx)  
      92.             s = Xmin[k] + (Xmin[ind]-Xmin[k]) * random.random()  
      93.             syntheticdata_X.append(s)  
      94.             syntheticdata_y.append(ymin[k])  
      95.               
      96.     print 'asyn, raw X size:',X.shape          
      97.     X = np.vstack((X,np.asarray(syntheticdata_X)))  
      98.      
      99.     y = np.hstack((y,syntheticdata_y))  
      100.     print 'asyn, post X size:',X.shape  
      101.       
      102.     return X , y  
      103.      
  • 相关阅读:
    LightOJ 1094
    hdu 2586
    hdu 5234
    hdu 2955
    LightOJ 1030 数学期望
    poj 1273
    CodeIgniter学习笔记(十五)——CI中的Session
    CodeIgniter学习笔记(十四)——CI中的文件上传
    CodeIgniter学习笔记(十三)——CI中的分页
    CodeIgniter学习笔记(十二)——CI中的路由
  • 原文地址:https://www.cnblogs.com/zhizhan/p/5042922.html
Copyright © 2011-2022 走看看