/********************************************************************
	created:	2014/04/29 11:35
	filename:	nth_element.cpp
	author:		Justme0 (http://blog.csdn.net/justme0)

	purpose:	nth_element, re-implemented after the SGI STL internals
				(median-of-three unguarded partition + insertion sort)
*********************************************************************/

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iterator>

// Element type used by the original demo. The algorithms below take the
// element type from std::iterator_traits, so they no longer depend on it.
typedef int Type;

/*
** Copy [first, last) backwards so that the copy ends just before `result`;
** returns the start of the destination range. Because copying proceeds from
** the back, this is safe when the destination overlaps the source and lies to
** its right (the shift-right-by-one done in linear_insert).
** Element-wise assignment (instead of memmove) keeps it valid for any
** assignable element type, not only trivially copyable ones.
*/
template <class BidirectionalIterator1, class BidirectionalIterator2>
inline BidirectionalIterator2 copy_backward(BidirectionalIterator1 first,
                                            BidirectionalIterator1 last,
                                            BidirectionalIterator2 result) {
    while (first != last) {
        *--result = *--last;
    }
    return result;
}

/*
** Insert `value` into the sorted range that ends just before `last`: larger
** elements are shifted one slot to the right and `value` is written into the
** resulting hole (starting at `last`).
** No lower-bound check is performed (hence "unguarded_"): the caller must
** guarantee that some element before `last` is not greater than `value`, so
** that the backward scan stops inside the range.
*/
template <class RandomAccessIterator, class T>
void unguarded_linear_insert(RandomAccessIterator last, T value) {
    RandomAccessIterator next = last;
    --next;
    while (value < *next) {
        *last = *next;
        last = next;
        --next;
    }
    *last = value;
}

/*
** Insert the element at `last` into the sorted range [first, last), so that
** [first, last] becomes sorted.
*/
template <class RandomAccessIterator>
void linear_insert(RandomAccessIterator first, RandomAccessIterator last) {
    typedef typename std::iterator_traits<RandomAccessIterator>::value_type value_type;

    value_type value = *last;
    if (value < *first) {
        // Smaller than the current minimum: shift the whole range right by
        // one position in a single pass and put value at the front.
        // Qualified call so ADL cannot also find std::copy_backward.
        ::copy_backward(first, last, last + 1);
        *first = value;
    } else {
        // *first <= value acts as a sentinel, so the unguarded scan cannot
        // move past first.
        unguarded_linear_insert(last, value);
    }
}

/*
** Sort [first, last) in ascending order (by operator<) using insertion sort.
*/
template <class RandomAccessIterator>
void insertion_sort(RandomAccessIterator first, RandomAccessIterator last) {
    if (first == last) {
        return;
    }

    for (RandomAccessIterator ite = first + 1; ite != last; ++ite) {
        linear_insert(first, ite);
    }
}

/*
** Return a reference to the median of a, b and c, using only operator<.
*/
template <class T>
inline const T &median(const T &a, const T &b, const T &c) {
    if (a < b) {
        if (b < c) {
            return b;       // a < b < c
        } else if (a < c) {
            return c;       // a < c <= b
        } else {
            return a;       // c <= a < b
        }
    } else if (a < c) {
        return a;           // b <= a < c
    } else if (b < c) {
        return c;           // b < c <= a
    } else {
        return b;           // c <= b <= a
    }
}

/*
** Swap the values referred to by two iterators. The temporary's type comes
** from the first iterator's traits, as in the STL source.
*/
template <class ForwardIterator1, class ForwardIterator2>
inline void iter_swap(ForwardIterator1 a, ForwardIterator2 b) {
    typename std::iterator_traits<ForwardIterator1>::value_type tmp = *a;
    *a = *b;
    *b = tmp;
}

/*
** Partition [first, last) around `pivot`. Returns `mid` such that every
** element in [first, mid) is <= pivot and every element in [mid, last) is
** >= pivot. This is the STL-internal routine used by nth_element and sort.
**
** Precondition: the range must contain an element not less than pivot and an
** element not greater than pivot (nth_element ensures this by choosing the
** median of three elements of the range). The inner scans have no bounds
** checks and rely on such elements to stop them.
**
** Why not std::partition: apart from avoiding bounds checks, both scans stop
** on elements equal to pivot (strict <) and swap them, so runs of equal
** elements are divided between the two sides rather than all landing on one
** side, which tends to keep the two halves balanced.
*/
template <class RandomAccessIterator, class T>
RandomAccessIterator unguarded_partition(RandomAccessIterator first, RandomAccessIterator last, T pivot) {
    while (true) {
        while (*first < pivot) {
            ++first;
        }
        --last;
        // With std::partition and predicate "x < pivot", this scan would also
        // skip elements equal to pivot (i.e. "<=" on the other side); here
        // equal elements stop it.
        while (pivot < *last) {
            --last;
        }
        if (!(first < last)) {  // comparing iterators with < requires random-access iterators
            return first;
        }
        ::iter_swap(first, last);  // qualified so ADL cannot also find std::iter_swap
        ++first;
    }
}

/*
** Rearrange [first, last) so that *nth is the element that would occupy that
** position if the whole range were sorted; elements before nth are <= *nth
** and elements after it are >= *nth. Requires first <= nth <= last.
** nth == last leaves the range untouched, as std::nth_element does.
*/
template <class RandomAccessIterator>
void nth_element(RandomAccessIterator first, RandomAccessIterator nth, RandomAccessIterator last) {
    typedef typename std::iterator_traits<RandomAccessIterator>::value_type value_type;

    if (nth == last) {
        return;
    }

    while (last - first > 3) {
        // Median-of-three pivot, copied as the range's own element type so
        // non-int elements are not truncated.
        RandomAccessIterator cut = unguarded_partition(first, last, value_type(median(
            *first,
            *(first + (last - first) / 2),
            *(last - 1))));
        // Continue only in the part that contains nth.
        if (cut <= nth) {
            first = cut;
        } else {
            last = cut;
        }
    }
    insertion_sort(first, last);
}


int main(int argc, char **argv) {
    int arr[] = {22, 30, 30, 17, 33, 40, 17, 23, 22, 12, 20};
    int size = sizeof arr / sizeof *arr;

    nth_element(arr, arr + 5, arr + size);

    for (int i = 0; i < size; ++i) {
        printf("%d ", arr[i]);  // 20 12 22 17 17 22 23 30 30 33 40
    }
    printf("\n");

#ifdef _WIN32
    // "PAUSE" is a Windows cmd built-in that waits for a key press; on other
    // platforms the shell has no such command, so it is skipped there.
    system("PAUSE");
#endif
    return 0;
}