水塘抽样 与 洗牌算法
本文介绍两个相似的问题,水塘抽样和洗牌算法。
水塘抽样(Reservoir Sampling)
水塘抽样(Reservoir Sampling)说的是这样一个问题:当内存无法完全加载时,如何从数据流或大数据集中随机选取k个样本,并保证每个样本被选取的概率相等。
典型问题出现在高纳德《计算机编程艺术》和谷歌面试题中都出现过:可否在一未知大小的集合中,随机取出一元素?/ 在不知道文件总行数的情况下,如何从文件中随机的抽取一行?
这个问题我们可以分两种情况进行讨论:
- k == 1
- k > 1
当 k == 1 的时候,我们可以在每次遇到合法对象时,以 1/n 的概率决定是否替换结果对象,其中 n 是当前遇到过的合法对象数目。显然,第一次遇到的时候肯定会替换成结果对象;第二次遇到的时候有一半可能替换,也就是前两个合法对象都有一半可能返回结果;第三次遇到的时候有 1/3 可能替换,前两个合法对象返回的可能都出自剩下这 2/3 可能,从而前三个合法对象返回的概率也一样 …… 归纳法可证所有合法对象返回的概率都相同,且概率总和为 1。
当 k > 1 的时候,我们在前 k 次遇到合法对象的时候直接存入结果数组;之后每一次遇到合法对象,都以 k/n 的可能来替换结果对象,结果数组中每个对象都等概率分到 1/k 份替换概率。参照上面的证明,归纳法可证所有合法对象返回的概率都相同,且概率总和为 k。
代码模板
// k == 1
int cnt = 0;
for (int i=0; i<arr.size(); i++) {
if (arr[i] != target) continue;
if ((rand() % ++cnt) == 0) res = i;
}
// k > 1
int cnt = 0;
for (int i=0; i<arr.size(); i++) {
if (arr[i] != target) continue;
if (cnt < k) {
res[cnt++] = i;
} else {
int j = (rand() % ++cnt);
if (j < k) res[j] = i;
}
}
下面给出几道例题:
LeetCode 398. Random Pick Index
保证指定的target一定会出现在数组中,可能出现多次,要求等概率给出其中一个合法下标。同时提示考虑内存较小的实际情况。
这就是典型的水塘抽样问题了。
/*
* @lc app=leetcode id=398 lang=cpp
*
* [398] Random Pick Index
*/
// @lc code=start
/*
class Solution {
unordered_map<int, vector<int>> val2idx;
public:
Solution(vector<int>& nums) {
for (int i=0; i<nums.size(); i++) {
val2idx[nums[i]].push_back(i);
}
}
int pick(int target) {
assert(!val2idx[target].empty());
auto&& v = val2idx[target];
return v[rand() % v.size()];
}
}; // AC, O(N) space, O(1) time
*/
// The array size can be very large.
// Don't use too much extra space.
class Solution {
vector<int> arr;
public:
Solution(vector<int>& nums) : arr(std::move(nums)) {}
int pick(int target) {
int cnt = 0;
int res = -1;
for (int i=0; i<arr.size(); i++) {
if (arr[i] != target) continue;
if ((rand() % ++cnt) == 0) {
res = i;
}
}
return res;
} // AC, O(1) space, O(N) time
};
/**
* Your Solution object will be instantiated and called as such:
* Solution* obj = new Solution(nums);
* int param_1 = obj->pick(target);
*/
// @lc code=end
这道题程序检查不严谨,用 O(N) 的内存也可以通过。但是在实际的应用场景中,这是不应该的,就连保存整个数组也不应该允许,只能是 std::move
符合要求,但是这会让不会新标准C++的人很为难 ……
LeetCode 382. Linked List Random Node
这道题换成了从链表中随机返回一个结点的元素,本质上没有变化,只要能顺序访问元素就可以,算法本身不要求随机访问。
/*
* @lc app=leetcode id=382 lang=cpp
*
* [382] Linked List Random Node
*/
// @lc code=start
/**
* Definition for singly-linked list.
* struct ListNode {
* int val;
* ListNode *next;
* ListNode() : val(0), next(nullptr) {}
* ListNode(int x) : val(x), next(nullptr) {}
* ListNode(int x, ListNode *next) : val(x), next(next) {}
* };
*/
class Solution {
public:
/** @param head The linked list's head.
Note that the head is guaranteed to be not null, so it contains at least one node. */
Solution(ListNode* head) {
ptr = head;
}
/** Returns a random node's value. */
int getRandom() {
int res = -1;
size_t n = 0;
ListNode* p = ptr;
while (p) {
if ((rand() % ++n) == 0) {
res = p->val;
}
p = p->next;
}
return res;
}
private:
ListNode* ptr;
}; // AC, Reservoir Sampling
/**
* Your Solution object will be instantiated and called as such:
* Solution* obj = new Solution(head);
* int param_1 = obj->getRandom();
*/
// @lc code=end
洗牌算法
高端的洗牌,往往只需要最简单的一遍扫描。(_)
大名鼎鼎的 Knuth shuffle,思想与前面的水塘抽样有些类似。不过需要注意两点:
- 从后往前扫,这样每个元素至多被交换一次,是确定性的;
- 交换时,随机选取的交换对象包括自身原本所在位置,保证每个元素都能取到每个位置。
概率的证明方法类似,也可用归纳法:
第一次交换,任何元素出现在最后一个位置的概率是 1/n;
第二次交换,任何元素出现在倒数第二个位置的概率是 (1-1/n) * 1/(n-1) = 1/n;
...
直到就剩一个元素不用交换。
代码模板
for (int i=arr.size()-1; i>0; i--) {
std::swap(arr[i], arr[rand() % (i+1)]);
}
LeetCode 384. Shuffle an Array
我们可以考虑手写洗牌算法,也可以直接调用 std::random_shuffle
来洗牌。或者,我们可以自定义生成器来使用 std::shuffle
洗牌,这个写起来更复杂一些。
/*
* @lc app=leetcode id=384 lang=cpp
*
* [384] Shuffle an Array
*/
// @lc code=start
class Solution {
public:
Solution(vector<int>& nums) : src(std::move(nums)) {}
/** Resets the array to its original configuration and return it. */
vector<int> reset() {
return src;
}
/** Returns a random shuffling of the array. */
vector<int> shuffle() {
if (src.empty()) return src;
vector<int> tmp(src);
// std::random_shuffle(tmp.begin(), tmp.end()); // AC
for (int i=tmp.size()-1; i>0; i--) {
std::swap(tmp[i], tmp[rand() % (i+1)]);
} // AC
return tmp;
}
private:
const vector<int> src;
};
/**
* Your Solution object will be instantiated and called as such:
* Solution* obj = new Solution(nums);
* vector<int> param_1 = obj->reset();
* vector<int> param_2 = obj->shuffle();
*/
// @lc code=end