算法专题:Merge Sort

说起归并排序(Merge Sort),其在排序界的地位可不低,毕竟O(nlogn)比较排序的三大排序方法,就是Quick Sort, Merge Sort和Heap Sort。归并排序是典型的分而治之方法,先来看看其最简单的递归实现:

def merge_sort(lst):
    """Sortsthe input list using the merge sort algorithm.

    # >>> lst = [4, 5, 1, 6, 3]
    # >>> merge_sort(lst)
    [1, 3, 4, 5, 6]
    """
    if len(lst) <= 1:
        return lst
    mid = len(lst) // 2
    left = merge_sort(lst[:mid])
    right = merge_sort(lst[mid:])
    return merge(left, right)

def merge(left, right):
    """Takestwo sorted lists and returns a single sorted list by comparing the
    elements one at a time.

    # >>> left = [1, 5, 6]
    # >>> right = [2, 3, 4]
    # >>> merge(left, right)
    [1, 2, 3, 4, 5, 6]
    """
    if not left:
        return right
    if not right:
        return left
    if left[0] < right[0]:
        return [left[0]] + merge(left[1:], right)
    return [right[0]] + merge(left, right[1:])

很明显,归并排序是典型的分而治之(Divide and Conquer,D&C)算法,思想就是先把两半数据分别排序,然后再归并到一起。这样T(n) = 2T(n/2) + O(n),由Master Theorem可以得到其时间复杂度是O(nlogn)。
再看具体的实现。排序主体函数用的是递归,归并算法一般都是这样;而merge部分其实也可以用迭代来完成:

def merge_2(left, right):
    p1 = p2 = 0
    temp = []
    while p1 < len(left) and p2 < len(right):
        if left[p1] <= right[p2]:
            temp.append(left[p1])
            p1 += 1
        else:
            temp.append(right[p2])
            p2 += 1
    while p1 < len(left):
        temp.append(left[p1])
        p1 += 1
    while p2 < len(right):
        temp.append(right[p2])
        p2 += 1
    return temp 

单纯就此情景而言,迭代的merge显得冗长而且效率没有提升。但是其好处就是适用性广,因为有很多merge sort的变形,不太方便递归调用merge函数。
变形:merge sort有很多tweak的应用,大部分是需要考虑数组前后关系。

例1. 给出一个数组nums,需要对每个数其index之后比它大的数的个数求和。例如给出[7, 4, 5, 2, 8, 9, 0, 1],返回11,因为7后面有8,9两个比它大的,4有3个,5有2个,2有2个,8有1个,0有1一个,总共2+3+2+2+1+1=11。

【解】
Method 1:一个naive的方法就是对于每一个数,遍历搜索其后面所有比其大的数,显然时间复杂度是O(n^2)。

Method 2:还有一个方法就是考虑Segment Tree,先构建从min到max的线段树,O(max(n) - min(n)),初始count都是0。然后从反方向考虑,考虑前面比其小的有多少个。也就是说对于某个n[i],考虑[min(n), n[i]-1]这个区间里面有多少count。完了再把n[i]的count++。就这个例子而言,先放7,然后4进来的时候搜索[0,3]区间,因为是要比4小,再把4的count设置为1;这样5进来的时候,就能搜索到4的存在。
这个算法后面的步骤是O(nlogn),但是需要构造一个线段树,假如max很大很大,就不太合适。当然也可以argue说我就干脆构造一个囊括最小到最大的32位int的线段树,这还是O(1)呢XD。

Method 3:这个解法考虑使用merge sort的tweak。因为要求每个n[i]之后比它大的数,可以用分而治之的思想,即考虑前一半有多少个,后一半有多少个,然后再考虑之间有多少个。
就这个例子而言,考虑最后一次merge之前的样子:[2,4,5,7][0,1,8,9],此时两半里面都已经计算完毕,只需要计算merge时候产生的结果。很明显结果是8,因为前一半的4个数都比8和9要小。但是如何计算呢?
考虑到后一半已经排好序,假如对后一半使用binary search,自然可以得到第一个大于前一半某一个数的index,从而获得所有大于这个数的个数。也就是说这个merge是O(nlogn)。那么总体就是T(n) = T(n/2) + O(nlogn),由Master Theorem可知复杂度是O(n(logn)^2)。

Method 4:但是,上面这个merge方法没有利用前一半也排好序的条件,因此可以做到更好。
考虑两个指针p1和p2,分别指向前一半和后一半。p1初始是在2,p2在0,因为此时n[p2] < n[p1]因此p2增加,直至指向8。那么因为第二半是递增的,p2后面的数肯定也满足,因此这时候就可以获得大于n[p1]的个数:第二半的长度-p2。然后呢?p1递增指向4,假如p2重新回到0然后扫描,这个复杂度就是O(n^2),比上面的二分查找还要差。
因此做一些调整,不递增p1,而是递增p2。也就是说换一个思路,不是从第二半里面找比第一半大的,而是从第一半里面找比第二半小的:刚开始还是p1指向2,p2指向0,然后因为n[p1] >= n[p2],因为第一半递增,后面的肯定也比n[p2]要大,因此没必要往后看,可以直接计算个数:p1-s,s是递归使用的开始的index,这里是0,也就是说对于n[p2]没有比其更小的。
然后递增p2,但需要注意的是p1不用复位,这是很关键的一点。为什么?因为p1停止的条件,要么就是已经扫完整个一半了,要么就是现在的n[p1]比n[p2-1]要大,也就是说现在的p1之前的都比n[p2-1]要小,而n[p2]>n[p2-1],因此前面那些根本就不需要比较就能知道结论,可以直接沿用之前的p1的位置。
在这个例子里面,比较明显的就是8和9.对于8,p1将会递增至第一半的长度,也就是说整个第一半都比8要小,那么对于9而言,比8大,因此整个第一半也都比9小,无需再从头比较。
再举一个一般性一点的例子:[2,4,5,7][0,1,6,8],对于6,p1将会停在7上面,计数是3;p2递增后,对于8,可以知道p1前面都是比6小的,那么肯定也就比8小,因此直接从p1在7上面开始,最后计数是4。
这样一来,merge函数两个指针就不需要走回头路,效率O(n),整体效率是O(nlogn),空间复杂度O(n)。当然,具体实现的时候,还是要把两半真正的merge排好序,因为上面的计算都是在两边都排好序的情况下进行的。当只有一个元素的时候可以直接返回0。代码如下:

# count number that larger than it and after it
def dc2(self, n, s, e):
    if s >= e:
        return 0
    m = (s + e) // 2
    ans = self.dc2(n, s, m) + self.dc2(n, m + 1, e)
    p1 = s
    for q in range(m + 1, e + 1):
        while p1 <= m and n[q] > n[p1]:
            p1 += 1
        ans += p1 - s
    # merge
    temp = []
    p1, p2 = s, m + 1
    while p1 <= m and p2 <= e:
        if n[p1] <= n[p2]:
            temp.append(n[p1])
            p1 += 1
        else:
            temp.append(n[p2])
            p2 += 1
    while p1 <= m:
        temp.append(n[p1])
        p1 += 1
    while p2 <= m:
        temp.append(n[p2])
        p2 += 1
    for i in range(len(temp)):
        n[i + s] = temp[i]
    return ans 

例2. 给出一个数组n和一个范围[a, b],求n有多少个子区间的和在[a,b]之内。假设数组n的元素和a,b都是整数。例如给出[2,3,4,1],和范围[3,5],那么子区间[3][4][2,3][4,1]都满足条件,返回4。

【解】
Method 1:naive方法就是找出所有的子区间,然后看有多少个满足条件。复杂度非常高。

Method 2:看到子区间之和,当然想到prefix sum。也就是说可以造一个数组s,每一个元素s[i] = n[0]+...+n[i]。那么所有的子区间除了n[0]都可以用s的后一个元素减去前一个元素获得。
也就是说,问题转换成为:给出一个数组s,计算有多少对ij,使得s[i] - s[j] in [a,b]而且i < j?
假如s是升序的,那么好说;但s是无序的。Naive方法就是对每一个s[i],都扫一遍后面元素看看能不能满足在区间a+s[i],b+s[i]里面,假如满足那么减去s[i]就在要求的区间里面。当然最后还需要比较一下单个的s元素。这个做法复杂度O(n^2)。

Method 3:在子区间prefix sum的基础上,考虑merge sort的tweak。假设n=[7, 4, 5, 2, 8, 9, 0, 1], a=0, b=7。
考虑最后一次merge之前的情况:[2,4,5,7][0,1,8,9].从上一题得到启发,假如对第一半里的每一个数s[i],在第二半里面二分查找第一个大于等于s[i]+a的index1,假如index1不存在那就不需要再找了,没有符合条件的;和第一个大于s[i]+b的index2,假如不存在那么index2=e也就是end的index。那么自然就可以得到个数index2 - index1。这样merge的复杂度是O(nlogn),总体O(n(logn)^2)。

Method 4:
在Method 3的基础上改进。类似于例1,Method 3的问题还是在于没有利用第一半排好序的条件。
考虑三个指针,p1p2和q,p1p2都指向第一半,p2指向第二半。 因为要利用第一半排序的条件,因此还是固定递增q。对于s[q],需要s[q] - s[i] 在区间[a,b]当中。也就是说s[q] - s[i] >= a, s[i] <= s[q] - a; s[q] - s[i] <= b, s[i] >= s[q] - b。
因此两个指针p1p2,p1不断递增直至不满足s[p1] <= s[q] - a,p2不断递增直至不满足s[p2] < s[q] - b。那么,p1之前的都是满足s[q] - s[i] >= a的,p2之后的都是满足s[q] - s[i] <= b,p1p2之间的就是满足条件的,即count+=p1-p2.
然后递增q,因为s[q] >= s[q-1],因此之前p1p2的位置可以延续,即s[q] - a >= s[q-1] - a >= s[i],也就是说p1之前p2之前的元素还是满足那些条件。因此,这个merge函数的复杂度是O(n),总体时间复杂度O(nlogn),空间复杂度O(n)。注意单个区间的情况已经被涵盖了。代码如下:

class Solution:
    # count numbers of subarray sum in range of [a,b]
    def countSubarraySum(self, nums, a, b):
        if not nums:
            return 0
        n = [0] * len(nums)
        for i in range(len(nums)):
            if i != 0:
                n[i] = nums[i] + n[i - 1]
            else:
                n[i] = nums[i]
        return self.dc(n, a, b, 0, len(n) - 1)

    # count number of prefix sum that x[i] - x[j] in [a, b] and i > j plus itself in [a, b]
    def dc(self, n, a, b, s, e):
        if s > e:
            return 0
        if s == e:
            return a <= n[s] <= b
        m = (s + e) // 2
        ans = self.dc(n, a, b, s, m) + self.dc(n, a, b, m + 1, e)
        p1 = p2 = s
        for q in range(m + 1, e + 1):
            while p1 <= m and n[q] - n[p1] >= a:
                p1 += 1
            while p2 <= m and n[q] - n[p2] > b:
                p2 += 1
            if p2 <= p1:
                ans += p1 - p2
        # merge
        temp = []
        p1, p2 = s, m + 1
        while p1 <= m and p2 <= e:
            if n[p1] <= n[p2]:
                temp.append(n[p1])
                p1 += 1
            else:
                temp.append(n[p2])
                p2 += 1
        while p1 <= m:
            temp.append(n[p1])
            p1 += 1
        while p2 <= m:
            temp.append(n[p2])
            p2 += 1
        for i in range(len(temp)):
            n[i + s] = temp[i]
        return ans

推荐阅读更多精彩内容

  • 背景 一年多以前我在知乎上答了有关LeetCode的问题, 分享了一些自己做题目的经验。 张土汪:刷leetcod...
    张土汪阅读 10,334评论 0 31
  • 贪心算法 贪心算法总是作出在当前看来最好的选择。也就是说贪心算法并不从整体最优考虑,它所作出的选择只是在某种意义上...
    fredal阅读 7,175评论 3 51
  • 概述 排序有内部排序和外部排序,内部排序是数据记录在内存中进行排序,而外部排序是因排序的数据很大,一次不能容纳全部...
    蚁前阅读 3,653评论 0 52
  • 概述:排序有内部排序和外部排序,内部排序是数据记录在内存中进行排序,而外部排序是因排序的数据很大,一次不能容纳全部...
    每天刷两次牙阅读 2,518评论 0 15
  • 1.插入排序—直接插入排序(Straight Insertion Sort) 基本思想: 将一个记录插入到已排序好...
    依依玖玥阅读 336评论 0 2