zoukankan      html  css  js  c++  java
  • Weka中EM算法详解

     1  private void EM_Init (Instances inst)
     2     throws Exception {
     3     int i, j, k;
     4 
     5     // 由于EM算法对初始值较敏感,故选择run k means 10 times and choose best solution
     6     SimpleKMeans bestK = null;
     7     double bestSqE = Double.MAX_VALUE;
     8     for (i = 0; i < 10; i++) {
     9       SimpleKMeans sk = new SimpleKMeans();
    10       sk.setSeed(m_rr.nextInt());
    11       sk.setNumClusters(m_num_clusters);
    12       sk.setDisplayStdDevs(true);
    13       sk.buildClusterer(inst);
    14       //KMeans中各个cluster的平方误差
    15       if (sk.getSquaredError() < bestSqE) {
    16          
    17           bestSqE = sk.getSquaredError();
    18           bestK = sk;
    19       }
    20     }
    21     
    22     /*************** KMeans Finds the best cluster number *****************/
    23     
    24     
    25     // initialize with best k-means solution
    26     m_num_clusters = bestK.numberOfClusters();
    27     // 每个样本所在各个集群的概率
    28     m_weights = new double[inst.numInstances()][m_num_clusters];
    29     // 评估每个集群所对应的离散型属性的相关取值
    30
    m_model = new DiscreteEstimator[m_num_clusters][m_num_attribs]; 31 // 每个集群所对应的连续性属性数所对应的相关取值(均值,标准偏差,样本权值(进行归一化)) 32 m_modelNormal = new double[m_num_clusters][m_num_attribs][3]; 33 // 每个集群所对应的先验概率 34 m_priors = new double[m_num_clusters]; 35 // 每个集群所对应的中心点 36 Instances centers = bestK.getClusterCentroids(); 37 // 每个集群所对应的标准差 38 Instances stdD = bestK.getClusterStandardDevs(); 39 // ??? Returns for each cluster the frequency counts for the values of each nominal attribute 40 int [][][] nominalCounts = bestK.getClusterNominalCounts(); 41 // 得到每个集群所对应的样本数 42 int [] clusterSizes = bestK.getClusterSizes(); 43 44 for (i = 0; i < m_num_clusters; i++) { 45 Instance center = centers.instance(i); 46 for (j = 0; j < m_num_attribs; j++) { 47 48 // 样本属性是离散型 49 if (inst.attribute(j).isNominal()) 50 { 51 m_model[i][j] = new DiscreteEstimator(m_theInstances.attribute(j).numValues() 52 , true); 53 for (k = 0; k < inst.attribute(j).numValues(); k++) { 54 m_model[i][j].addValue(k, nominalCounts[i][j][k]); 55 } 56 } 57 //// 样本属性是连续型 58 else 59 { 60 double minStdD = (m_minStdDevPerAtt != null)? m_minStdDevPerAtt[j]: m_minStdDev; 61 double mean = (center.isMissing(j))? inst.meanOrMode(j): center.value(j); 62 m_modelNormal[i][j][0] = mean; 63 double stdv = (stdD.instance(i).isMissing(j))? ((m_maxValues[j] - 64 m_minValues[j]) / (2 * m_num_clusters)): stdD.instance(i).value(j); 65 if (stdv < minStdD) 66 { 67 stdv = inst.attributeStats(j).numericStats.stdDev; 68 if (Double.isInfinite(stdv)) { 69 stdv = minStdD; 70 } 71 if (stdv < minStdD) { 72 stdv = minStdD; 73 } 74 } 75 if (stdv <= 0) { 76 stdv = m_minStdDev; 77 } 78 79 m_modelNormal[i][j][1] = stdv; 80 m_modelNormal[i][j][2] = 1.0; 81 } 82 } 83 } 84 85 86 for (j = 0; j < m_num_clusters; j++) { 87 // 计算每个集群的先验概率 88 m_priors[j] = clusterSizes[j]; 89 } 90 Utils.normalize(m_priors); 91 }
  • 相关阅读:
    keyCode 与charCode
    阻止事件冒泡的三种手段
    jquery实现二级菜单
    static public和 public static 区别
    java单例模式
    使用jqueryui
    正则表达式
    PHP中mysql_affected_rows()和mysql_num_rows()区别
    PHP中冒号、endif、endwhile、endfor这些都是什么
    jqueryMobile
  • 原文地址:https://www.cnblogs.com/likai198981/p/3170568.html
Copyright © 2011-2022 走看看