Top K Frequent Elements (M)
题目
Given a non-empty array of integers, return the *k* most frequent elements.
Example 1:
Input: nums = [1,1,1,2,2,3], k = 2
Output: [1,2]
Example 2:
Input: nums = [1], k = 1
Output: [1]
Note:
- You may assume k is always valid, 1 ≤ k ≤ number of unique elements.
- Your algorithm's time complexity must be better than O(n log n), where n is the array's size.
- It's guaranteed that the answer is unique, in other words the set of the top k frequent elements is unique.
- You can return the answer in any order.
题意
求一个数组中出现次数最多的前k个元素。要求时间复杂度小于$O(NlogN)$。
思路
最直接的方法是维护一个最大大小为k的小顶堆,向其中添加数组中的不重复元素,按照出现次数排序,如果添加后堆大小为k+1,则将堆顶元素去除,这样最终留下的就是k个出现次数最多的元素,复杂度为$O(Nlogk)$。
top k类型的问题还可以用 快速选择算法 来解决,平均复杂度为$O(N)$。
代码实现
Java
优先队列
class Solution {
public int[] topKFrequent(int[] nums, int k) {
Map<Integer, Integer> record = new HashMap<>();
Queue<Integer> q = new PriorityQueue<>((a, b) -> record.get(a) - record.get(b));
for (int num : nums) {
record.put(num, record.getOrDefault(num, 0) + 1);
}
for (int num : record.keySet()) {
q.offer(num);
if (q.size() > k) {
q.poll();
}
}
int[] res = new int[k];
int i = 0;
while (!q.isEmpty()) {
res[i++] = q.poll();
}
return res;
}
}
快速选择
class Solution {
public int[] topKFrequent(int[] nums, int k) {
Map<Integer, Integer> record = new HashMap<>();
for (int num : nums) {
record.put(num, record.getOrDefault(num, 0) + 1);
}
int[] arr = new int[record.size()];
int i = 0;
for (int num : record.keySet()) {
arr[i++] = num;
}
int left = 0, right = arr.length - 1;
int[] res = new int[k];
while (left <= right) {
int mid = partition(arr, left, right, record);
if (k - 1 < mid) {
right = mid - 1;
} else if (k - 1 > mid) {
left = mid + 1;
} else {
for (int j = 0; j < k; j++) {
res[j] = arr[j];
}
return res;
}
}
return res;
}
private int partition(int[] arr, int left, int right, Map<Integer, Integer> record) {
int tmp = arr[left];
while (left < right) {
while (left < right && record.get(arr[right]) < record.get(tmp)) {
right--;
}
arr[left] = arr[right];
while (left < right && record.get(arr[left]) >= record.get(tmp)) {
left++;
}
arr[right] = arr[left];
}
arr[left] = tmp;
return left;
}
}