[本文链接:http://www.cnblogs.com/breezedeus/archive/2012/09/05/2671572.html,转载请注明出处]
在K-means聚类算法里,我们首先需要在已有的数据点中选取K个点作为初始中心点。这个bug就出现在中心点的随机选取上,mahout的实现不是真的随机。
【位置】:
org.apache.mahout.clustering.kmeans.RandomSeedGenerator#buildRandom(...) , 行 88 - 110 这段。
我简化了一下,mahout的随机抽取逻辑如下:
1: /**
2: * Sample K integers from integer interval [0, N).
3: * @param N
4: * @param K
5: * @return
6: */
7: private List<Integer> generateMahoutRandomSeed( int N, int K) {
8: List<Integer> chosen = Lists. newArrayListWithCapacity(K);
9: Random random = RandomUtils. getRandom();
10: for ( int n = 0; n < N; ++n) {
11: int currentSize = chosen.size();
12: if (currentSize < K) {
13: chosen.add(n);
14: } else if (random.nextInt(currentSize + 1) != 0) {
15: int indexToRemove = random.nextInt(currentSize); // evict one chosen randomly
16: chosen.remove(indexToRemove);
17: chosen.add(n);
18: }
19: }
20: return chosen;
21: }
1: private List<Integer> generateBDRandomSeed(int N, int K) {
2: List<Integer> chosen = Lists.newArrayListWithCapacity(K);
3: Random random = RandomUtils.getRandom();
4: for (int n = 0; n < N; ++n) {
5: int currentSize = chosen.size();
6: if (currentSize < K) {
7: chosen.add(n);
8: } else if (random.nextInt(n + 1) < K) {
9: int indexToRemove = random.nextInt(currentSize); // actually currentSize is always equal to K here
10: chosen.remove(indexToRemove);
11: chosen.add(n);
12: }
13: }
14: return chosen;
15: }
为了说明问题,我对上面的代码写了个测试函数:
1: @Test
2: public void testMahoutRandomSeedGenerator() {
3: int N = 11;
4: int K = 3;
5: int numLoops = 100000;
6: int[] times = new int [N];
7: Arrays. fill(times, 0);
8: for ( int loop = 0; loop < numLoops; ++loop) {
9: //List<Integer> chosen = generateMahoutRandomSeed(N, K);
10: List<Integer> chosen = generateWjlRandomSeed(N, K);
11: for (Integer i : chosen) {
12: ++times[i];
13: }
14: }
15: for ( int n = 0; n < N; ++n) {
16: System. out .println(times[n] / ( double)(numLoops*K) );
17: }
18: }
使用generateMahoutRandomSeed产生的结果如下:
0.03344333333333333
0.033523333333333336
0.033356666666666666
0.033036666666666666
0.044596666666666666
0.059013333333333334
0.07856
0.10567666666666667
0.14029
0.18848333333333334
0.25002
而使用修正后的generateBDRandomSeed,其结果如下:
0.09096
0.09011
0.09051666666666666
0.09143333333333334
0.09082
0.09034333333333333
0.09032
0.09103
0.09162
0.09148666666666666
0.09136
数学上可以证明我的算法是对的,证明可见我之前讲面试题目的一篇老博文。感兴趣的童鞋也可以想想为什么generateMahoutRandomSeed的实现不对:)
mahout的这个问题也可能出现在它其他随机抽取相关的代码中,所以建议用到mahout随机抽取代码的同学check一下再使用。