zoukankan      html  css  js  c++  java
  • Median of Two Sorted Arrays——算法课上经典的二分和分治算法

    There are two sorted arrays nums1 and nums2 of size m and n respectively. Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).

    看到这道题的通过率很诧异,感觉这道题挺容易的,因为其实它的思想还是很简单的。

    1)最笨的方法去实现,利用排序将两个数组合并成一个数组,然后返回中位数,这种方法应该会超时。

    2)利用类似merge的操作找到中位数,利用两个分别指向A和B数组头的指针去遍历数组,然后统计元素个数,直到找到中位数,此时算法复杂度为O(n)。

    我一开始想到的就是2)这种方法,但是真正的写起代码来,才发现有很多的细节需要去考虑,挺繁琐的。

    我们仅需要第k大的元素,不需要排序这个复杂的操作:可以定义一个计数器m,表示找到了第m大的元素;同时指针pa,pb分别指向数组A,B的第一个元素,使用merge-sort的方式,当A的当前元素小于B的当前元素时:pa++, m++,当*pb < *pa时,pb++, m++。最终当m等于k时,就得到了第k大的元素。时间复杂度O(k),但是当k接近于m+n时,复杂度还是O(m+n);

    3)

    从题目中的要求O(log(m+n))可以联想到肯定要用到二分查找的思想

    那么有没有更好的方案?我们可以考虑从k入手。如果我们每次能够删除一个一定处于第k大元素之前的元素,那么需要进行k次。但是如果我们每次都能删除一半呢?可以利用A,B有序的信息,类似二分查找,也是充分利用有序。 
    假设A 和B 的元素个数都大于k/2,我们将A 的第k/2 个元素(即A[k/2-1])和B 的第k/2个元素(即B[k/2-1])进行比较,有以下三种情况(为了简化这里先假设k 为偶数,所得到的结论对于k 是奇数也是成立的): 
    - A[k/2 - 1] == B[k/2 - 1];    

    就是这里感觉不太明白,如果k==3,感觉怎么凑都不对啊!!!!!

    解释:这里可能有人会有疑问,如果k为奇数,则m不是中位数。这里是进行了理想化考虑,在实际代码中略有不同,是先求k/2,然后利用k-k/2获得另一个数。
    - A[k/2 - 1] > B[k/2 - 1]; 
    - A[k/2 - 1] < B[k/2 - 1]; 
    如果A[k/2 - 1] < B[k/2 - 1] ,意味着 A[0] 到 A[k/2 - 1] 的元素一定小于 A+B 第k大的元素。因此可以放心的删除A数组中的这k/2个元素; 
    同理,A[k/2 - 1] > B[k/2 - 1];可以删除B数组中的k/2个元素; 
    当A[k/2 - 1] == B[k/2 - 1] 时,说明找到了第k大的元素,直接返回A[k/2 - 1] 或B[k/2 - 1]的值。

    因此可以写一个递归实现,递归终止条件是什么呢? 
    - A或B为空时,直接返回A[k-1] 或 B[k-1] 
    - 当k = 1时,返回min(A[0], B[0]) //第1小表示第一个元素 
    - 当A[k/2 - 1] == B[k/2 - 1] 时,返回A[k/2 - 1] 或B[k/2 - 1]

    我们可以看出,代码非常简洁,而且效率也很高。在最好情况下,每次都有k一半的元素被删除,所以算法复杂度为logk,由于求中位数时k为(m+n)/2,所以算法复杂度为log(m+n)。 

    转自:http://www.bubuko.com/infodetail-797383.html

    最终实现的代码为(代码是直接拷贝的别人的,到时候还需要自己写一下):

        static int find_kth(int* A, int m,int* B, int n, int k);
        static int min(p, q) {return (p < q) ? p : q;}
        double findMedianSortedArrays(int* nums1, int nums1Size, int* nums2, int nums2Size) {
            int m = nums1Size;
            int n = nums2Size;
            int total = m+n;
            int k = total/2;
            if(total & 0x01) {   //这里用total%2==1也行
                return find_kth(nums1, m, nums2, n, k+1); //奇数,返回唯一中间值
            } else {
                return (find_kth(nums1, m, nums2, n, k) + find_kth(nums1, m, nums2, n, k+1)) / 2.0; //偶数,返回中间2个的平均值
            }
        }
        //找到A,B组合中第k小的值: AB[k-1]
        int find_kth(int* A, int m,int* B, int n, int k) {
            //假设m都小于n
            if (m > n)
                return find_kth(B, n, A, m, k);
            if (m == 0)
                return B[k-1];
            if (k == 1) //终止条件
                return min(A[0], B[0]);
    
            int i_a = min(m, k/2);    //这两步的意义就是之前解释的遇到奇数时的情况
            int i_b = k - i_a;
    
            if (A[i_a-1] < B[i_b-1])
                return find_kth(A+i_a, m-i_a, B, n, k-i_a);
            else if (A[i_a-1] > B[i_b-1])
                return find_kth(A, m, B+i_b, n-i_b, k-i_b);
            else
                return A[i_a-1];
        }
    

      下面是我更改的C++的版:

    class Solution {
    public:
        int findk(vector<int>& nums1,int len1, vector<int>& nums2,int len2,int k)
        {
            
            if(len1>len2)
                return findk(nums2,len2,nums1,len1,k);
            if(len1==0)
                return nums2[k-1];
            if(k==1)
                return min(nums1[0],nums2[0]);
            int flag1=min(len1,k/2);
            int flag2=k-flag1;
            if(nums1[flag1-1]<nums2[flag2-1])
            {
                vector<int> temp(nums1.begin()+flag1,nums1.end());
                return findk(temp,len1-flag1,nums2,len2,k-flag1); 
            }
            else if(nums1[flag1-1]>nums2[flag2-1])
            {
                vector<int> temp(nums2.begin()+flag2,nums2.end());
                return findk(nums1,len1,temp,len2-flag2,k-flag2);
            }
                
            else
                return nums1[flag1-1];
        }
        
        double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
            int len1=nums1.size();
            int len2=nums2.size();
            int k=(len1+len2)/2;
            if((len1+len2)%2==1)
                return findk(nums1,len1,nums2,len2,k+1);
            else
                return (findk(nums1,len1,nums2,len2,k)+findk(nums1,len1,nums2,len2,k+1))/2.0;
        }
    };
    

      以下的是错误的,但是错在哪呢?

    class Solution {
    public:
        double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
            int len1=nums1.size();
            int len2=nums2.size();
            int mid=(len1+len2)/2;
            int res=0;
            if((len1+len2)%2==1)
                res=getRes(nums1,len1,nums2,len2,mid+1);
            else
                res=(getRes(nums1,len1,nums2,len2,mid)+getRes(nums1,len1,nums2,len2,mid+1))/2.0;
            return res;
        }
        double getRes(vector<int>& nums1, int len1,vector<int>& nums2,int len2,int mid)
        {
           
            if(len1<len2)
                return getRes(nums2,len2,nums1,len1,mid);
            if(len2==0)
                return nums1[mid-1];
            if(mid==1)
                return min(nums1[0],nums2[0]);
            int k1=min(len2,mid/2);
            int k2=mid-k1;
            int res=0;
            if(nums1[k1-1]<nums2[k2-1])
            {
                vector<int> temp( nums1.begin()+k1,nums1.end());
                res=getRes(temp,len1-k1,nums2,len2,mid-k1);
            }
            else  if(nums1[k1-1]>nums2[k2-1])
            {
                vector<int> temp( nums2.begin()+k2,nums2.end());
                res=getRes(nums1,len1,temp,len2-k2,mid-k2);
            }
            else
                res= nums1[k1-1];
            return res;
        }
        
    };
  • 相关阅读:
    LeetCode 914. 卡牌分组
    LeetCode 999. 车的可用捕获量
    LeetCode 892. 三维形体的表面积
    21航电5E
    min25筛 学习笔记
    牛客多校6G
    2021航电多校3
    2021牛客多校H
    [模版] 快速傅里叶变换
    2021牛客多校第五场
  • 原文地址:https://www.cnblogs.com/qiaozhoulin/p/4779003.html
Copyright © 2011-2022 走看看