问题描述

There are two sorted arrays A and B of size m and n respectively. Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).

想法与实现

这道题跟前面的题目都有些不一样,不一样的地方在于它规定了算法的时间复杂度。一开始我想先写一个简单的AC版本,先合并有序的数组,然后再在常数时间内找到他们的Median Number。这样的时间复杂度是O(min(m, n))。实现很简单,就不贴了。

本来是想看看这个方法的报错会是什么,没想到直接AC了。这说明他们的OJ服务器应该不检查代码的时间复杂度。你把合并的数组用sort方法去合并,还是能过。

后来就开始想题目要求复杂度的算法了,看到log,应该比较自然会想到二分,或者说分治。我最先想到的是快速选择。借鉴这种思想,我觉得是可以这样来实现的:将中位数的问题转换为求第K大元素的问题。而求解第K大元素的问题,应该怎么来做呢。首先,这是有序数组,那求第K大问题应该还是比较自然可以想到,类似二分查找的做法。而这个问题的二分查找跟一般的二分查找不同的地方在于,他是对两个数组进行二分查找,需要不断缩小两个数组的搜索范围。

log(m) + log(n)

我自己做了半天,老是遇到问题,于是就放弃了= =后来参考了一篇文章,感觉这种方法,还是挺难想的。首先需要说一个简单的观察:

Maintaining the invariant
i + j = k – 1,
If Bj-1 <= Ai <= Bj, then Ai must be the k-th smallest,
Else if Ai-1 <= Bj <= Ai, then Bj must be the k-th smallest.

这个观察,就是在i+j=k-1的基础上,确定第K小的元素的方法。那接下来就要看看如何缩减这个问题的规模。缩减规模,也是建立在一个观察上的。大致就是说,如果在满足上面的约束的情况下,如果Ai和Bj都不是第K大的元素,那么必定一个大于第K大的元素,另外一个小于第K大的元素。所以就可以对问题的范围进行缩减了,每次判断,对于小的那个元素,可以排除掉在该数组中比该元素更小的那些,对于大的元素,可以排除掉比他更大的元素,这样就可以不断的减小问题的规模。整个的代码就是这样的~

double findKthElement(int A[], int m, int B[], int n, int k) 
{
    int i = (int)((double)m / (m+n) * (k-1));
    int j = (k-1) - i;
    // invariant: i + j = k-1
    // Note: A[-1] = -INF and A[m] = +INF to maintain invariant
    int Ai_1 = ((i == 0) ? INT_MIN : A[i-1]);
    int Bj_1 = ((j == 0) ? INT_MIN : B[j-1]);
    int Ai   = ((i == m) ? INT_MAX : A[i]);
    int Bj   = ((j == n) ? INT_MAX : B[j]);

    if (Bj_1 <= Ai && Ai <= Bj)
        return Ai;
    else if (Ai_1 <= Bj && Bj <= Ai)
        return Bj;

    // if none of the cases above, then it is either:
    if (Ai < Bj)
    // exclude Ai and below portion
    // exclude Bj and above portion
        return findKthElement(A+i+1, m-i-1, B, j, k-i-1);
    else /* Bj < Ai */
    // exclude Ai and above portion
    // exclude Bj and below portion
        return findKthElement(A, i, B+j+1, n-j-1, k-j-1);
}

double findMedianSortedArrays(int A[], int m, int B[], int n) {
	int k = (m + n + 1) / 2;
    if ((m + n) % 2 == 1)
        return findKthElement(A, m, B, n, k);
    else
    {
        return ((double) findKthElement(A, m, B, n, k) + (double) findKthElement(A, m, B, n, k + 1)) / 2.0;
    }
}

评论