在编程珠玑这本书中,看到了有关蓄水池抽样算法的例子。就是说在一大堆不知道个数的数据中等概率随机抽取K个数据。
思路:如果我们知道n的值,那么问题就可以简单的用一个大随机数rand()%n得到一个确切的随机位置,那么该位置的对象就是所求的对象,选中的概率是1/n。但现在我们并不知道n的值,这个问题便抽象为蓄水池抽样问题,即从一个包含n个对象的列表S中随机选取k个对象,n为一个非常大或者不知道的值。通常情况下,n是一个非常大的值,大到无法一次性把所有列表S中的对象都放到内存中。我们这个问题是蓄水池抽样问题的一个特例,即k=1。
解法:我们总是选择第一个对象,以1/2的概率选择第二个,以1/3的概率选择第三个,以此类推,以1/m的概率选择第m个对象。当该过程结束时,每一个对象具有相同的选中概率,即1/n,证明如下。
证明:
对应问题伪代码:
//伪代码 i = 0 while more input items with probability 1.0 / ++i choice = this input item print choice
实现代码:
1 #include <iostream> 2 #include <cstdlib> 3 #include <ctime> 4 #include <vector> 5 6 using namespace std; 7 8 typedef vector<int> IntVec; 9 typedef typename IntVec::iterator Iter; 10 typedef typename IntVec::const_iterator Const_Iter; 11 12 // generate a random number between i and k, 13 // both i and k are inclusive. 14 int randint(int i, int k) 15 { 16 if (i > k) 17 { 18 int t = i; i = k; k = t; // swap 19 } 20 int ret = i + rand() % (k - i + 1); 21 return ret; 22 } 23 24 // take 1 sample to result from input of unknown n items. 25 bool reservoir_sampling(const IntVec &input, int &result) 26 { 27 srand(time(NULL)); 28 if (input.size() <= 0) 29 return false; 30 31 Const_Iter iter = input.begin(); 32 result = *iter++; 33 for (int i = 1; iter != input.end(); ++iter, ++i) 34 { 35 int j = randint(0, i); 36 if (j < 1) 37 result = *iter; 38 } 39 return true; 40 } 41 42 int main() 43 { 44 const int n = 10; 45 IntVec input(n); 46 int result = 0; 47 48 for (int i = 0; i != n; ++i) 49 input[i] = i; 50 if (reservoir_sampling(input, result)) 51 cout << result << endl; 52 return 0; 53 }
对应蓄水池抽样问题,可以类似的思路解决。先把读到的前k个对象放入“水库”,对于第k+1个对象开始,以k/(k+1)的概率选择该对象,以k/(k+2)的概率选择第k+2个对象,以此类推,以k/m的概率选择第m个对象(m>k)。如果m被选中,则随机替换水库中的一个对象。最终每个对象被选中的概率均为k/n,证明如下。
证明:第m个对象被选中的概率=选择m的概率*(其后元素不被选择的概率+其后元素被选择的概率*不替换第m个对象的概率),即
蓄水池抽样伪代码:
//伪代码 array S[n]; //source, 0-based array R[k]; // result, 0-based integer i, j; // fill the reservoir array for each i in 0 to k - 1 do R[i] = S[i] done; // replace elements with gradually decreasing probability for each i in k to n do j = random(0, i); // important: inclusive range if j < k then R[j] = S[i] fi done
实现代码(该版本假设直到n大小,但n非常大):
1 #include <iostream> 2 #include <cstdlib> 3 #include <ctime> 4 5 using namespace std; 6 7 // generate a random number between i and k, 8 // both i and k are inclusive. 9 int randint(int i, int k) 10 { 11 if (i > k) 12 { 13 int t = i; i = k; k = t; // swap 14 } 15 int ret = i + rand() % (k - i + 1); 16 return ret; 17 } 18 19 // take m samples to result from input of n items. 20 bool reservoir_sampling(const int *input, int n, int *result, int m) 21 { 22 srand(time(NULL)); 23 if (n < m || input == NULL || result == NULL) 24 return false; 25 for (int i = 0; i != m; ++i) 26 result[i] = input[i]; 27 28 for (int i = m; i != n; ++i) 29 { 30 int j = randint(0, i); 31 if (j < m) 32 result[j] = input[i]; 33 } 34 return true; 35 } 36 37 int main() 38 { 39 const int n = 100; 40 const int m = 10; 41 int input[n]; 42 int result[m]; 43 44 for (int i = 0; i != n; ++i) 45 input[i] = i; 46 if (reservoir_sampling(input, n, result, m)) 47 for (int i = 0; i != m; ++i) 48 cout << result[i] << " "; 49 cout << endl; 50 return 0; 51 }
实现代码(该版本不知道n大小):
1 #include <iostream> 2 #include <cstdlib> 3 #include <ctime> 4 #include <vector> 5 6 using namespace std; 7 8 typedef vector<int> IntVec; 9 typedef typename IntVec::iterator Iter; 10 typedef typename IntVec::const_iterator Const_Iter; 11 12 // generate a random number between i and k, 13 // both i and k are inclusive. 14 int randint(int i, int k) 15 { 16 if (i > k) 17 { 18 int t = i; i = k; k = t; // swap 19 } 20 int ret = i + rand() % (k - i + 1); 21 return ret; 22 } 23 24 // take m samples to result from input of n items. 25 bool reservoir_sampling(const IntVec &input, IntVec &result, int m) 26 { 27 srand(time(NULL)); 28 if (input.size() < m) 29 return false; 30 31 result.resize(m); 32 Const_Iter iter = input.begin(); 33 for (int i = 0; i != m; ++i) 34 result[i] = *iter++; 35 36 for (int i = m; iter != input.end(); ++i, ++iter) 37 { 38 int j = randint(0, i); 39 if (j < m) 40 result[j] = *iter; 41 } 42 return true; 43 } 44 45 int main() 46 { 47 const int n = 100; 48 const int m = 10; 49 IntVec input(n), result(m); 50 51 for (int i = 0; i != n; ++i) 52 input[i] = i; 53 if (reservoir_sampling(input, result, m)) 54 for (int i = 0; i != m; ++i) 55 cout << result[i] << " "; 56 cout << endl; 57 return 0; 58 }
本文参考:http://www.cnblogs.com/HappyAngel/archive/2011/02/07/1949762.html