最后更新
四刷?
K largest什么的题是面试的高频题,这次尝试搞清楚。
首先是双指针,锁定区间,每次找一个pivot,然后按这个分边。
我是选一个pivot,然后比Pivot大的放左边,小的放右边,然后看看K落在P的左边还是右边,来缩小区间。
最坏的情况每次都找到最小的元素,而我们需要最大的(k=1),这样就是O(n²)
具体根据pivot是每次从左从右找2个能交换的。
public class Solution {
public int findKthLargest(int[] nums, int k) {
return partition(0, nums.length - 1, k, nums);
}
public int partition(int start, int end, int k, int[] nums) {
int pivot = nums[start];
int l = start;
int r = end;
while (l <= r) {
while (l <= r && nums[l] >= pivot) l ++;
while (l <= r && nums[r] <= pivot) r --;
if (l < r) {
swap(l, r, nums);
}
}
swap(start, r, nums);
if (r + 1 == k) {
return nums[r];
} else if (r + 1 < k) {
return partition(r+1, end, k, nums);
} else {
return partition(start, r-1, k, nums);
}
}
public void swap(int l, int r, int[] nums) {
int temp = nums[l];
nums[l] = nums[r];
nums[r] = temp;
}
}
另一种做法是借鉴sort color的partition方法。
partition的时候卡了一下,最好的办法还是把大于PIVOT的放左边。最后R停留的位置一定是最后一个大于Pivot的值。
public class Solution {
public int findKthLargest(int[] nums, int k) {
return partition(0, nums.length - 1, k, nums);
}
public int partition(int start, int end, int k, int[] nums) {
int pivot = nums[start];
int left = start;
int right = end;
int temp = start;
while (temp <= right) {
if (nums[temp] > pivot) {
swap(temp ++, left ++, nums);
} else if (nums[temp] == pivot) {
temp ++;
} else {
swap(temp, right --, nums);
}
}
if (right + 1 == k) {
return nums[right];
} else if (right + 1 < k) {
return partition(right+1, end, k, nums);
} else {
return partition(start, right-1, k, nums);
}
}
public void swap(int l, int r, int[] nums) {
int temp = nums[l];
nums[l] = nums[r];
nums[r] = temp;
}
}
时间复杂度是一样的,最差的情况还是类似于54321 的数组然后K=5这种。
二刷。
这个题还有印象,用的quick select,quick sort的一部分。
时间上是需要证明为什么是O(n)。。
最坏的情况是pivot每次选极端值,最后就是n2,这个不用说了。
最好的情况是我们每次都选的是中间值,那最终结果就是:
n/2 + n/4 + n/8 + .. + n/n = n - 1
就是O(n)。
为了保证我们取值合适,得用适当的方式来选取pivot,而不是随机从里面区间抓一个。
public class Solution {
public int findKthLargest(int[] nums, int k) {
if (nums.length == 1) return nums[0];
return quickSelect(0,nums.length-1,k,nums);
}
public int quickSelect(int l, int r, int k, int[] nums) {
int m = getMid(l,r,nums);
int target = nums[m];
swap(nums,r,m);
int left = l;
int i = l;
while (i < r) {
if (nums[i] < target) {
swap(nums, i, left);
left++;
}
i++;
}
swap(nums,left,r);
if (left == nums.length - k) {
return nums[left];
} else if (left > nums.length - k) {
return quickSelect(l, left-1, k, nums);
} else {
return quickSelect(left+1, r, k, nums);
}
}
public int getMid(int l, int r, int[] nums) {
int a = nums[l];
int b = nums[r];
int m = l + (r - l) / 2;
int c = nums[m];
if (a > b) {
if (b > c) return r;
else return a > c? m : l;
} else {
if (a > c) return l;
else return b > c? m: r;
}
}
public void swap(int[] nums, int a, int b) {
int temp = nums[a];
nums[a] = nums[b];
nums[b] = temp;
}
}
看讨论区发现自己发的POST,我他妈二刷居然不如一刷写的好。。。
https://discuss.leetcode.com/topic/55501/2ms-java-quick-select-only-2-points-to-mention
三刷。
首先用正常的PQ来做,维持大小为K的minHeap,所有元素往里面加,多了就POLL出去。。
最后顶上的就是要求的。
Time: O(n lgk)
Space: O(k)
public class Solution {
public int findKthLargest(int[] nums, int k) {
if (nums.length == 1) return nums[0];
PriorityQueue<Integer> pq = new PriorityQueue<Integer>(k);
for (int i : nums) {
pq.offer(i);
if (pq.size() > k) {
pq.poll();
}
}
return pq.poll();
}
}
然后quick select, 基本想法是,只继续SORT可能存在K的那半部分。
每次“尽量"选一个合适的pivot值,然后进行partition,这个地方刚卡了一下= = 记录的应该是小于pivot的而不是大于pivot的。。
Time: O(n) average.. O(n²) worst..
public class Solution {
public int findKthLargest(int[] nums, int k) {
if (nums.length == 1) return nums[0];
return quickSelect(nums, k, 0, nums.length - 1);
}
public int quickSelect(int[] nums, int k, int l, int r) {
int m = betterMid(nums, l ,r);
int pivot = nums[m];
swap(nums, m, r);
int smaller = l;
for (int i = l; i < r; i++) {
if (nums[i] < pivot) {
swap(nums, i, smaller ++);
}
}
swap(nums, smaller, r);
if (smaller + k == nums.length) {
return nums[smaller];
} else if (smaller + k > nums.length) {
return quickSelect(nums, k, l, smaller - 1);
} else {
return quickSelect(nums, k, smaller + 1, r);
}
}
public void swap(int[] nums, int indexA, int indexB) {
int temp = nums[indexA];
nums[indexA] = nums[indexB];
nums[indexB] = temp;
}
public int betterMid(int[] nums, int l, int r) {
int m = l + (r - l) / 2;
if (nums[l] > nums[r]) {
if (nums[m] > nums[r]) {
return nums[l] > nums[m] ? m : l;
} else {
return r;
}
} else {
if (nums[m] > nums[l]) {
return nums[r] > nums[m] ? m : r;
} else {
return l;
}
}
}
}