zoukankan      html  css  js  c++  java
  • TopK问题,数组中第K大(小)个元素问题总结

    问题描述:

      在未排序的数组中找到第 k 个最大的元素。请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。

    面试中常考的问题之一,同时这道题由于解法众多,也是考察时间复杂度计算的一个不错的问题。

    1,选择排序

      利用选择排序,将数组中最大的元素放置在数组的最前端,然后第k次选择的最大元素就是第K大个元素,直接根据索引返回结果即可。

    public class Select {
        public static void main(String[] args) {
            int[] arr = new int[]{5,3,2,1,4,7,8,10,6,9};
            System.out.println(findKthLargest(arr, 3));
        }
        private static int findKthLargest(int[] arr, int k){
            if(k <= 0 || k > arr.length)
                throw new IllegalArgumentException("k error");
            for(int i = 0; i < k; ++i){
                int maxNum = Integer.MIN_VALUE;
                int maxIndex = -1;
                for(int j = i; j < arr.length; ++j){
                    if(arr[j] > maxNum){
                        maxNum = arr[j];
                        maxIndex = j;
                    }
                }
                swap(arr, maxIndex, i);
            }
            System.out.println(Arrays.toString(arr));
            return arr[k-1];
        }
        private static void swap(int[] arr, int i, int j){
            int temp = arr[i];
            arr[i] = arr[j];
            arr[j] = temp;
        }
    }

      结果:

    [10, 9, 8, 1, 4, 7, 2, 5, 6, 3]
    8

      我们可以看到数组经过选择排序后,前三个元素分别是三趟选择中最大的元素,直接返回k-1索引位置的元素,即是第K大的元素。

      时间复杂度O(n*K),经过K次选择,每次选择都要遍历n个元素。

    2,排序优化

      上一个方法的本质实际上是将整个数组进行一个排序,然后根据索引位置得到答案,基于这个情况我们可以使用一些更快速的排序方法,例如选择排序或归并排序,以达到平局时间复杂度为O(nlogn)

    public class Sort {
        public static void main(String[] args) {
            int[] arr = new int[]{5,3,2,1,4,7,8,10,6,9};
            System.out.println(findKthLargest(arr, 2));
        }
        private static int findKthLargest(int[] arr, int k){
            if(k <= 0 || k > arr.length)
                throw new IllegalArgumentException("k error");
            Arrays.sort(arr);
            System.out.println(Arrays.toString(arr));
            return arr[arr.length-k];
        }
    }

      结果:

    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    9

      时间复杂度O(nlogn),最坏时间复杂度根据不同的排序方法而不一样,快排的话就是O(n^2),归并排序是O(nlogn)。

    3,堆(优先队列)

      思路是创建一个最小堆,将所有数组中的元素加入堆中,并保持堆的大小小于等于 k。这样,堆中就保留了前 k 个最大的元素。这样,堆顶的元素就是正确答案。

    public class Heap {
        public static void main(String[] args) {
            int[] arr = new int[]{5,3,2,1,4,7,8,10,6,9};
            System.out.println(findKthLargest(arr, 3));
        }
        private static int findKthLargest(int[] arr, int k){
            if(k <= 0 || k > arr.length)
                throw new IllegalArgumentException("k error");
            PriorityQueue<Integer> queue = new PriorityQueue<>((a,b)->{
                return a-b;
            });
            for(int num:arr){
                queue.offer(num);
                if(queue.size() > k)
                    queue.poll();
            }
            return queue.peek();
        }
    }

      时间复杂度是O(nlogk),向大小为 k 的堆中添加或删除元素的时间复杂度为O(logk),遍历n个元素,故总时间复杂度为 O(nlogk)

    4,快速选择

      基于快排的思想,选出一个基准元素,将数组划分成两部分,左侧的元素都比基准元素大,右侧的都比基准元素小,如果基准元素的索引恰好等于k-1,也就是说这个基准元素就是第k大的元素,否则根据基准元素的位置再去左边或者右边去选择。

    import java.util.PriorityQueue;
    import java.util.Random;
    
    public class QuickSelect {
        public static void main(String[] args) {
            int[] arr = new int[]{5,3,2,1,4,7,8,10,6,9};
            System.out.println(findKthLargest(arr, 10));
        }
        private static int findKthLargest(int[] arr, int k){
            if(k <= 0 || k > arr.length)
                throw new IllegalArgumentException("k error");
            return quickSelect(arr, 0, arr.length-1, k);
        }
        private static int quickSelect(int[] arr, int left, int right, int k){
            if(left == right)
                return arr[left];
            Random random_num = new Random();
            int pivotIndex = left + random_num.nextInt(right - left);
            pivotIndex = partition(arr, left, right, pivotIndex);
            if(pivotIndex == k-1){
                return arr[pivotIndex];
            }else if(pivotIndex < k-1){
                return quickSelect(arr, pivotIndex+1, right, k);
            }else{
                return quickSelect(arr, left, pivotIndex-1, k);
            }
        }
        private static int partition(int[] arr, int left, int right, int pivotIndex){
            int pivot = arr[pivotIndex];
            swap(arr, pivotIndex, right);
            int l = left, r = right;
            while(l < r){
                while(l < r && arr[l] >= pivot)
                    l++;
                if(arr[l] < pivot)
                    swap(arr, l, r);
                while(l < r && arr[r] <= pivot)
                    r--;
                if(arr[r] > pivot)
                    swap(arr, l, r);
            }
            return l;
        }
        private static void swap(int[] arr, int i, int j){
            int temp = arr[i];
            arr[i] = arr[j];
            arr[j] = temp;
        }
    }

      这里我们选择一个数组中的随机值作为基准值,如果每次恰好都划分一半的元素的话,则T(n) = n + n/2 + n/4 + n/8 + n/16 + ... = 2n,也就是O(n)的时间复杂度。

      但如果每一次选择的元素恰好是最小值的话,时间复杂度则退化到了O(n^2)

      但是平均时间复杂度是O(n),算法导论上有严格的证明。

    5,BFPRT

      在BFPRT算法中,仅仅是改变了快速排序Partion中的pivot值的选取,在快速排序中,我们始终选择第一个元素或者最后一个元素作为pivot,而在BFPTR算法中,每次选择五分中位数的中位数作为pivot,这样做的目的就是使得划分比较合理,从而避免最坏情况的发生。算法步骤如下:

    1. 将输入数组的n个元素划分为n/5组,每组5个元素,且至多只有一个组由剩下的n%5个元素组成。
    2. 寻找n/5个组中每一个组的中位数,首先对每组的元素进行插入排序,然后从排序过的序列中选出中位数。
    3. 对于2中找出的n/5个中位数,递归进行步骤1和2,直到只剩下一个数即为这n/5个元素的中位数,找到中位数后并找到对应的下标p。
    4. 进行Partion划分过程,Partion划分中的pivot元素下标为p。
    5. 进行高低区判断即可

      本算法的最坏时间复杂度为O(n),值得注意的是通过BFPTR算法将数组按第K小(大)的元素划分为两部分,而这高低两部分不一定是有序的,通常我们也不需要求出顺序,而只需要求出前K大的或者前K小的。

    public class BFPRT {
        public static void main(String[] args) {
            int[] arr = new int[]{3,2,3,1,2,4,5,5,6};
            System.out.println(findKthLargest(arr, 4));
        }
        private static int findKthLargest(int[] arr, int k){
            if(k <= 0 || k > arr.length)
                throw new IllegalArgumentException("k error");
            return quickSelect(arr, 0, arr.length-1, k);
        }
        private static int findMedian(int[] arr, int l, int r){
            int i = l, index = 0;
            for(; i + 4 <= r; i += 5, index++){
                sort(arr, i, i + 4);
                swap(arr, l + index, i + 2);
            }
            if(i <= r){
                sort(arr, i, r);
                swap(arr, l+index, i + (r-i+1) / 2); //如果是最后数组元素是偶数选择较小的一个
                index++;
            }
            if(index == 1)
                return l;
            else
                return findMedian(arr, l, l+index-1);
        }
        private static int quickSelect(int[] arr, int left, int right, int k){
            if(left == right)
                return arr[left];
    //        Random random = new Random();
    //        int pivotIndex = left + random.nextInt(right - left);
            int pivotIndex = findMedian(arr, left, right);
            pivotIndex = partition(arr, left, right, pivotIndex);
            if(pivotIndex == k-1){
                return arr[pivotIndex];
            }else if(pivotIndex < k-1){
                return quickSelect(arr, pivotIndex+1, right, k);
            }else{
                return quickSelect(arr, left, pivotIndex-1, k);
            }
        }
        private static int partition(int[] arr, int left, int right, int pivotIndex){
            int pivot = arr[pivotIndex];
            swap(arr, pivotIndex, right);
            int l = left, r = right;
            while(l < r){
                while(l < r && arr[l] >= pivot)
                    l++;
                if(arr[l] < pivot)
                    swap(arr, l, r);
                while(l < r && arr[r] <= pivot)
                    r--;
                if(arr[r] > pivot)
                    swap(arr, l, r);
            }
            return l;
        }
        private static void swap(int[] arr, int i, int j){
            int temp = arr[i];
            arr[i] = arr[j];
            arr[j] = temp;
        }
        public static void sort(int[] arr, int l, int r){
            for(int i = l; i <= r; i++){
                for(int j = i+1; j <= r; j++){
                    if(arr[j] < arr[i])
                        swap(arr, i, j);
                }
            }
        }
    }

  • 相关阅读:
    事务的特性(ACID)
    网络代理
    防止SpringMVC拦截器拦截js等静态资源文件
    Tomcat配置虚拟目录
    SpringMVC总结(SSM)
    Spring声明式事务总结
    Linux中MySQL忽略表中字段大小写
    MySQL之sql文件的导入导出
    MyBatis总结
    Linux网络
  • 原文地址:https://www.cnblogs.com/silentteller/p/13166377.html
Copyright © 2011-2022 走看看