LeetCode 973. K Closest Points to Origin
题目描述
Given an array of points where points[i] = [xi, yi] represents a point on the X-Y plane and an integer k, return the k closest points to the origin (0, 0).
The distance between two points on the X-Y plane is the Euclidean distance (i.e., √((x1 - x2)^2 + (y1 - y2)^2)).
You may return the answer in any order. The answer is guaranteed to be unique (except for the order that it is in).
Example 1:
Input: points = [[1,3],[-2,2]], k = 1
Output: [[-2,2]]
Explanation:
The distance between (1, 3) and the origin is sqrt(10).
The distance between (-2, 2) and the origin is sqrt(8).
Since sqrt(8) < sqrt(10), (-2, 2) is closer to the origin.
We only want the closest k = 1 points from the origin, so the answer is just [[-2,2]].
Example 2:
Input: points = [[3,3],[5,-1],[-2,4]], k = 2
Output: [[3,3],[-2,4]]
Explanation: The answer [[-2,4],[3,3]] would also be accepted.
Constraints:
- 1 <= k <= points.length <= 10^4
- -10^4 < xi, yi < 10^4
解题思路
前k大元素和第k大元素可以视为同一类题,都是 topK 问题,经典解法有堆和快速选择算法两种。
思路一:堆
C++ 中有现成的 priority_queue 可用,注意默认是大根堆,并且自定义 comparator 注意写法。以"距离更小"作为 comparator 时,堆顶是当前距离最大的点;堆大小超过 k 就弹出堆顶,遍历结束后堆中剩下的就是距离最近的 k 个点。
时间复杂度 O(K+NlogK),空间复杂度 O(K)。
思路二:快速选择算法
C++中有 nth_element 可用,同样要注意函数参数、以及 comparator 的写法。
我们也可以选择手写快速选择算法。
平均时间复杂度 O(N),手写版本在 pivot 选择不佳时最坏会退化为 O(N^2);额外空间复杂度 O(1)。
参考代码
这里比较的是到原点的距离,我们可以直接比较 x^2 + y^2 而不必开平方:开平方在非负数上单调递增,所以比较结果是一样的。按题目的数据范围,x^2 + y^2 < 2×10^8,用 int 计算不会溢出。
堆
/*
* @lc app=leetcode id=973 lang=cpp
*
* [973] K Closest Points to Origin
*/
// @lc code=start
class Solution {
public:
    // Heap approach: keep a max-heap keyed on squared distance to the origin,
    // never holding more than k points. Each time the heap grows past k, the
    // farthest point is evicted. After the scan it holds exactly the k closest points.
    //
    // @param points  points given as [x, y]; each inner vector needs >= 2 elements.
    // @param k       number of points to return, 1 <= k <= points.size().
    // @return        the k closest points, ordered from farthest to closest
    //                (the problem accepts any order).
    vector<vector<int>> kClosest(vector<vector<int>>& points, int k) {
        assert(k >= 1 && static_cast<size_t>(k) <= points.size());
        using P = vector<int>;
        // Squared distance suffices: sqrt is monotonic, so the ordering is identical.
        // "a < b" by distance makes priority_queue a max-heap: top() is the farthest point.
        auto cmp = [](const P& a, const P& b) {
            return a[0] * a[0] + a[1] * a[1] < b[0] * b[0] + b[1] * b[1];
        };
        // vector is the idiomatic, contiguous backing store for a binary heap.
        priority_queue<P, vector<P>, decltype(cmp)> q(cmp);
        const size_t limit = static_cast<size_t>(k);  // avoid signed/unsigned comparison
        for (const auto& p : points) {
            q.push(p);
            if (q.size() > limit) {
                q.pop();  // evict the current farthest point
            }
        }
        vector<vector<int>> res;
        res.reserve(q.size());
        while (!q.empty()) {
            res.push_back(q.top());
            q.pop();
        }
        return res;
    }  // AC
};
// @lc code=end
快速选择 nth_element
// nth_element(begin, nth, end)
// Quickselect via the standard library. std::nth_element partitions `points`
// so that no element before position k is farther from the origin than the
// element placed at position k. The first k elements are therefore the answer,
// though they are not sorted among themselves.
//
// @param points  points given as [x, y]; reordered and truncated to k in place.
// @param k       number of points to return, 1 <= k <= points.size().
// @return        a copy of the k closest points, in no particular order.
vector<vector<int>> kClosest(vector<vector<int>>& points, int k) {
    assert(k >= 1 && static_cast<size_t>(k) <= points.size());
    // Squared Euclidean norm; skipping sqrt keeps the same ordering.
    const auto squared_norm = [](const vector<int>& pt) {
        return pt[0] * pt[0] + pt[1] * pt[1];
    };
    const auto closer = [&squared_norm](const vector<int>& lhs, const vector<int>& rhs) {
        return squared_norm(lhs) < squared_norm(rhs);
    };
    // When k == points.size(), nth equals end(). Every point is then part of the answer.
    const auto nth = points.begin() + k;
    nth_element(points.begin(), nth, points.end(), closer);
    points.resize(k);
    return points;
}  // AC
手写快速选择算法
// Squared Euclidean distance of point p = [x, y] from the origin.
// The square root is omitted because it is monotonic, so comparisons are unaffected.
// Takes a const reference because it only reads p; assumes p has >= 2 elements.
// Under the stated constraints (|x|, |y| < 10^4) the result fits in an int.
static inline int dist(const vector<int>& p) {
    return p[0] * p[0] + p[1] * p[1];
}
// Hoare-style "hole" partition of points[l..r] (inclusive) by distance to the origin.
// On return, with t being the returned index:
//   - points[t] holds the pivot,
//   - every point in [l, t) is strictly closer than the pivot,
//   - every point in (t, r] is at least as far as the pivot.
// The middle element is used as the pivot. Always taking points[l] degrades to
// O(N^2) when the input is already ordered by distance.
int partation(vector<vector<int>>& points, int l, int r) {
    swap(points[l], points[l + (r - l) / 2]);
    // Take the pivot out; position l becomes the "hole" that gets filled next.
    vector<int> pivot = std::move(points[l]);
    const int pivot_dist = dist(pivot);  // hoisted: the pivot never changes
    while (l < r) {
        // From the right, find a point closer than the pivot and move it into the hole.
        while (l < r && dist(points[r]) >= pivot_dist) --r;
        if (l < r) points[l++] = std::move(points[r]);  // hole moves to r
        // From the left, find a point not closer than the pivot and move it into the hole.
        while (l < r && dist(points[l]) < pivot_dist) ++l;
        if (l < r) points[r--] = std::move(points[l]);  // hole moves to l
    }
    // The only moved-from slot is always the current hole, which is never read;
    // it ends at l == r, where the pivot is written back.
    points[l] = std::move(pivot);
    return l;
}
// Iterative quickselect on distance to the origin. For a 0-based k with
// 0 <= k < points.size(), it rearranges `points` so that:
//   - no point before index k is farther than points[k], and
//   - no point after index k is closer than points[k].
// The active window [lo, hi] always contains k and shrinks after every partition.
void find_kth(vector<vector<int>>& points, int k) {
    int lo = 0;
    int hi = static_cast<int>(points.size()) - 1;
    for (;;) {
        if (lo >= hi) return;  // zero or one element left: already in place
        const int pos = partation(points, lo, hi);
        if (pos == k) return;
        if (pos < k) {
            lo = pos + 1;  // target lies right of the pivot
        } else {
            hi = pos - 1;  // target lies left of the pivot
        }
    }
}
// Entry point for the hand-written quickselect. It moves the k closest points
// to the front of `points` (in no particular order), truncates `points` to k
// elements, and returns a copy of them.
//
// @param points  points given as [x, y]; reordered and shrunk in place.
// @param k       number of points to return, 1 <= k <= points.size().
vector<vector<int>> kClosest(vector<vector<int>>& points, int k) {
    assert(k >= 1 && static_cast<size_t>(k) <= points.size());
    const int last_kept = k - 1;  // 0-based index of the k-th closest point
    find_kth(points, last_kept);
    points.resize(k);
    return points;
}  // AC