最近复习算法相关的内容,在实现快排的时候遇到一些小问题,总结一下。

基本思想

快排的基本思想比较简单:

  • 选取数组一个元素作为主元。
  • 遍历一遍数组,将小于主元的元素放在数组左边,大于或等于的元素放在数组右边。
  • 分别递归排序数组左边和右边的元素。

算法核心的部分就是将数组分成2部分的分割算法,主要有2种方法,下面分别描述。

Lomuto分割

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
// len(input) >= 2
func partition1(input []int) int {
        k := 0
        right := len(input) - 1
        pivot := input[right]
        // 循环不变量
        // input [0, k) < pivot
        // input [k, i) >= pivot
        for i := 0; i < right; i++ {
                if input[i] < pivot {
                        if k != i {
                                input[k], input[i] = input[i], input[k]
                        }
                        k++
                }
        }
        input[k], input[right] = input[right], input[k]
        return k
}

这个是算法导论里介绍快排时候采用的算法,比较容易理解,但效率不是最高的,不过实现简单且不容易出错。这里的返回值是主元的位置,严格来说把数组分割成了3部分:左边小于主元,中间是主元,右边大于等于主元。 对应的快排驱动程序为:

1
2
3
4
5
6
7
func quickSort1(input []int) {
        if len(input) >= 2 {
                pivot := partition1(input)
                quickSort1(input[:pivot])
                quickSort1(input[pivot+1:])
        }
}

Hoare分割

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
// len(input) >= 2
func partition2(input []int) int {
        left := -1
        right := len(input)
        pivot := input[(right-1)/2]
        for {
                for left++; input[left] < pivot; left++ {
                }
                // [0, left) <= pivot, input[left] >= pivot

                for right--; input[right] > pivot; right-- {
                }
                // (right, len(input)) >= pivot, input[right] <= pivot

                // left == right || left = right+1
                if left >= right {
                        return right
                }
                input[left], input[right] = input[right], input[left]
        }
}

这个是原始快排算法的分割方式,看起来也是很简单但是里面有些地方容易犯错,导致算法进入死循环。跟第一种方法不同的是,这里是将数组分成2个部分,左边小于等于主元,右边大于等于主元。 可以证明下面的结论:

  • pivot的选取位置需要特殊考虑,如果是最后一个元素上面的实现会有问题
  • left和right不会越出数组的范围
  • 返回的分割位置right满足:0 <= right < len(input)-1
  • 最后数组满足:input[:right+1] <= input[right+1:]

算法驱动程序如下:

1
2
3
4
5
6
7
func quickSort2(input []int) {
        if len(input) >= 2 {
                pivot := partition2(input)
                quickSort2(input[:pivot+1])
                quickSort2(input[pivot+1:])
        }
}

参考资料