思路
- 不用传统排序算法,如何从小到大排序一个数组?
- 把原数组的元素放到最大堆中,不断取出根节点,从后往前赋值到原数组
堆的实现
- 完全二叉树:二叉树,除了最后一层外,其他层的节点数都是最大值,最后一层所有的节点都在左侧
- 堆:是一个完全二叉树,任何一个节点都不大于它的父节点的堆是最大堆
- 堆的存储
- 利用数组,左节点的索引是父节点的2倍
- 堆的维护
- Shift Up:在数组末尾加入新元素,将其调整到合适的位置(和父节点比较,交换)
- Shift Down:取出根节点元素(优先级最大),将末尾元素放到根节点,调整元素位置(和子节点中较大的交换)
- 堆化(Heapify)
- 给定一个数组,让数组的排列形成一个堆
- 第一个非叶子节点的索引号 = 末尾元素的索引号/2
- 所有叶子节点本身构成最大堆
- 对非叶子节点进行Shift Down
Heap.h
1 #include <algorithm> 2 #include <cassert> 3 4 using namespace std; 5 6 template<typename Item> 7 class MaxHeap{ 8 private: 9 Item *data; 10 int count; 11 int capacity; 12 13 // 比较新加入的子节点和父节点大小并交换,直到最顶 14 // O(logn) 15 void shiftUp(int k){ 16 while( k > 1 && data[k/2] < data[k] ){ 17 swap(data[k/2],data[k]); 18 k /= 2; 19 } 20 } 21 22 // 比较两个子节点,和大的交换,直到最底 23 // 复杂度O(logn) 24 void shiftDown(int k){ 25 // 是否有左孩子 26 while( 2*k <= count ){ 27 int j = 2*k; 28 // 是否有右孩子&&右孩子是否比左孩子大 29 if( j+1 <= count && data[j+1] > data[j]) j++; 30 if( data[k] >= data[j]) break; 31 swap(data[k] , data[j]); 32 k = j; 33 } 34 } 35 public: 36 // 构造一个空堆,可容纳capacity个元素 37 // 时间复杂度O(nlogn) 38 MaxHeap(int capacity){ 39 data = new Item[capacity+1]; 40 count = 0; 41 this -> capacity = capacity; 42 } 43 44 // 通过一个给定数组创建最大堆(堆化) 45 // 时间复杂度O(n) 46 MaxHeap(Item arr[], int n){ 47 data = new Item[n+1]; 48 capacity = n; 49 for( int i = 0 ; i < n ; i ++ ) 50 data[i+1] = arr[i]; 51 count = n; 52 53 // 堆的第一个非叶子节点的索引为count/2 54 // 从这个节点开始,逐个shiftDown到最顶 55 for( int i = count/2 ; i >= 1 ; i -- ) 56 shiftDown(i); 57 } 58 ~MaxHeap(){ 59 delete[] data; 60 } 61 // 返回堆中元素个数 62 int size(){ 63 return count; 64 } 65 // 是否为空堆 66 bool isEmpty(){ 67 return count == 0; 68 } 69 // 插入新元素,shiftUp 70 void insert(Item item){ 71 assert( count +1 <= capacity ); 72 data[count+1] = item; 73 shiftUp(count+1); 74 count ++; 75 } 76 // 从最大堆取出堆顶元素,shiftDown 77 Item extractMax(){ 78 assert( count > 0 ); 79 Item ret = data[1]; 80 swap( data[1] , data[count] ); 81 count --; 82 shiftDown(1); 83 return ret; 84 } 85 // 获取最大堆的堆顶元素 86 Item getMax(){ 87 assert( count > 0 ); 88 return data[1]; 89 } 90 };
main.cpp
1 #include <iostream> 2 #include <algorithm> 3 #include "Student.h" 4 #include "Heap.h" 5 #include "SortTestHelper.h" 6 #include "MergeSort.h" 7 #include "QuickSort.h" 8 9 using namespace std; 10 11 12 // 先将元素添加到堆中再取出,即完成排序 13 // 添加和取出的复杂度均为O(nlogn),故整体复杂度为O(nlogn) 14 15 template<typename T> 16 void heapSort1(T arr[], int n){ 17 18 MaxHeap<T> maxheap = MaxHeap<T>(n); 19 for( int i = 0 ; i < n ; i ++ ) 20 maxheap.insert(arr[i]); 21 22 for( int i = n-1 ; i >= 0 ; i--) 23 arr[i] = maxheap.extractMax(); 24 } 25 26 template<typename T> 27 void heapSort2(T arr[], int n){ 28 MaxHeap<T> maxheap = MaxHeap<T>(arr,n); 29 for( int i = n-1 ; i >= 0 ; i--) 30 arr[i] = maxheap.extractMax(); 31 } 32 33 // 优化的shiftDown过程,用赋值取代swap,从0开始索引 34 template<typename T> 35 void __shiftDown2(T arr[], int n, int k){ 36 // k为起始位置 37 T e = arr[k]; 38 // 左孩子不越界 39 while( 2*k+1 < n ){ 40 int j = 2*k+1; 41 // 右孩子比左孩子大 42 if( j + 1 < n && arr[j+1] > arr[j] ) 43 j += 1; 44 if( e >= arr[j] ) break; 45 // 和左、右孩子中较大的交换 46 arr[k] = arr[j]; 47 // k向下传递 48 k = j; 49 } 50 // 将堆头元素放到堆尾 51 arr[k] = e; 52 } 53 54 // 原地堆排序,直接在原数组上进行,不需额外空间 55 template<typename T> 56 void heapSort(T arr[], int n){ 57 // 堆化 58 // 从0开始索引 59 // 第一个非叶子节点索引 = (最后一个元素索引-1)/2 60 for( int i = (n-1-1)/2 ; i >= 0 ; i-- ) 61 __shiftDown2(arr,n,i); 62 // 不断把堆顶元素换到最后,完成排序 63 for( int i = n-1; i>0 ; i -- ){ 64 swap( arr[0] , arr[i] ); 65 __shiftDown2(arr, i, 0); 66 } 67 } 68 69 int main(){ 70 int n = 10000; 71 //int *arr1 = SortTestHelper::generateNearlyOrderedArray(n,10); 72 int *arr1 = SortTestHelper::generateRandomArray(n,0,10); 73 int *arr2 = SortTestHelper::copyIntArray(arr1, n); 74 //SortTestHelper::testSort("Insertion Sort",insertionSort,arr1,n); 75 //SortTestHelper::testSort("Merge Sort",mergeSort,arr2,n); 76 SortTestHelper::testSort("Heap Sort",heapSort,arr1,n); 77 SortTestHelper::testSort("Quick Sort 3",quickSort3,arr2,n); 78 79 delete[] arr1; 80 delete[] arr2; 81 82 return 0; 83 }
应用
- 优先队列:出队顺序和入队顺序无关,和优先级有关,优先级最高的出列,队列中的成员是动态增减的
- 操作系统调度任务,先执行哪个
- 不同上网用户对同一个服务器发出请求,服务器先回应哪个
- 游戏角色攻击视野中的敌人,先攻击哪个(最强的?血最少的?...)
- 在N个元素中选出前M个元素
- 排序的话是NlogN,利用优先队列可达到NlogM(N越大,效率提升越明显)
- 实现优先队列
- 普通数组:入队O(1),出队O(n)
- 顺序数组:入队O(n),出队O(1)
- 堆:入队O(lgn),出队O(lgn)
- 对于总共N个请求,使用数组,最差情况均为O(n^2),使用堆为O(nlgn)
总结
- 堆的存储:在数组中,下标从1开始
- shiftUp干了什么:在堆尾部新插入元素后,调整堆以维护堆的定义
- shiftDown干了什么:从堆头取出一个元素后,将堆尾元素放到堆头,调整堆以维护堆的定义
- 直接建堆和堆化的区别:直接建堆是一个个插入,堆化是在现拥有数组的基础上调整数据
- 如何优化:额外用了O(n)的空间,可改为原地排序