这算是一道比较全面的 Treap 模板题。需要注意的是：查询比 x 小的元素个数时，x 不一定在 Treap 中；解决办法是先插入 x，查询后再删除 x。
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
using namespace std;

// Treap node: a BST on key v, heap-ordered on the random priority r
// (a child's r never exceeds its parent's after insert/remove), augmented
// with the subtree size so rank / k-th queries run along one root-to-leaf path.
struct Node
{
    Node * ch[2];   // ch[0] = left child (smaller keys), ch[1] = right child (larger keys)
    int v, r, size; // key, random priority, number of nodes in this subtree

    // Direction to descend for key x: -1 if x equals this node's key,
    // 0 for the left subtree, 1 for the right subtree.
    int cmp( int x ) const
    {
        if ( x == v ) return -1;
        return x < v ? 0 : 1;
    }

    // Recompute size from the children; must be called after any structural change.
    void maintain()
    {
        size = 1;
        if ( ch[0] != NULL ) size += ch[0]->size;
        if ( ch[1] != NULL ) size += ch[1]->size;
    }
};

// Size of a possibly-empty subtree.
static int node_size( const Node * o )
{
    return o == NULL ? 0 : o->size;
}

// Rotate the subtree rooted at o so that its child on side d^1 becomes the new root.
// d == 0 is a left rotation (right child rises), d == 1 a right rotation (left child rises).
// o is updated in place to point at the new subtree root.
void rotate( Node * & o, int d )
{
    Node * k = o->ch[d ^ 1];
    o->ch[d ^ 1] = k->ch[d];
    k->ch[d] = o;
    o->maintain(); // o is now below k, so refresh it first
    k->maintain();
    o = k;
}

// Insert key x into the treap rooted at o (set semantics: a key already present
// is left untouched, which also avoids indexing ch[-1]).
void insert( Node * & o, int x )
{
    if ( o == NULL )
    {
        o = new Node();
        o->ch[0] = o->ch[1] = NULL;
        o->v = x;
        o->r = rand();
        o->size = 1;
        return;
    }
    int d = o->cmp(x);
    if ( d == -1 ) return; // duplicate key: nothing to insert
    insert( o->ch[d], x );
    // Restore the heap property: lift the child if its priority beats ours.
    if ( o->ch[d]->r > o->r )
    {
        rotate( o, d ^ 1 );
    }
    else
    {
        o->maintain();
    }
}

// Remove key x from the treap rooted at o; does nothing if x is not present.
void remove( Node * & o, int x )
{
    if ( o == NULL ) return; // x not in the tree
    int d = o->cmp(x);
    if ( d == -1 )
    {
        if ( o->ch[0] != NULL && o->ch[1] != NULL )
        {
            // Rotate the higher-priority child up; x then sits one level down on side dd.
            int dd = ( o->ch[0]->r > o->ch[1]->r ? 1 : 0 );
            rotate( o, dd );
            remove( o->ch[dd], x );
        }
        else
        {
            // At most one child: splice it into our place.
            Node * u = o;
            o = ( o->ch[0] == NULL ? o->ch[1] : o->ch[0] );
            delete u;
        }
    }
    else
    {
        remove( o->ch[d], x );
    }
    if ( o != NULL ) o->maintain();
}

// Return sum + (number of keys in the subtree o that are strictly less than x).
// Works whether or not x itself is stored: when x is absent the search simply
// falls off the tree and the count accumulated so far is the answer, so no
// temporary insert/delete of x is required.
int ranker( Node * o, int x, int sum )
{
    if ( o == NULL ) return sum;
    int d = o->cmp(x);
    if ( d == -1 )
    {
        return sum + node_size( o->ch[0] );
    }
    else if ( d == 0 )
    {
        return ranker( o->ch[0], x, sum );
    }
    else
    {
        // Everything in the left subtree plus this node is smaller than x.
        return ranker( o->ch[1], x, sum + node_size( o->ch[0] ) + 1 );
    }
}

// Return the k-th smallest key (1-based) in the subtree o.
// Precondition: 1 <= k <= node_size(o); the caller validates this.
int kth( Node * o, int k )
{
    int tmp = node_size( o->ch[0] );
    if ( k == tmp + 1 ) return o->v;
    else if ( k < tmp + 1 ) return kth( o->ch[0], k );
    else return kth( o->ch[1], k - tmp - 1 );
}

// Return 1 if key x is stored in the subtree o, 0 otherwise.
int find( Node * o, int x )
{
    if ( o == NULL ) return 0;
    int d = o->cmp(x);
    if ( d == -1 ) return 1;
    return find( o->ch[d], x );
}

// Free every node in the subtree o and reset o to NULL.
void destroy( Node * & o )
{
    if ( o == NULL ) return;
    destroy( o->ch[0] );
    destroy( o->ch[1] );
    delete o;
    o = NULL;
}

// Reads q commands "<op> <num>":
//   I num : insert num (ignored if already present)
//   D num : delete num (ignored if absent)
//   K num : print the num-th smallest element, or "invalid" if out of range
//   otherwise (C num) : print how many stored elements are smaller than num
// Each answer is printed on its own line.
int main ()
{
    Node * root = NULL;
    char op[16]; // room for the command token; the %15s width below prevents overflow
    int q, num;
    if ( scanf("%d", &q) != 1 ) return 0;
    while ( q-- )
    {
        if ( scanf("%15s%d", op, &num) != 2 ) break;
        if ( op[0] == 'I' )
        {
            if ( !find( root, num ) )
            {
                insert( root, num );
            }
        }
        else if ( op[0] == 'D' )
        {
            if ( find( root, num ) )
            {
                remove( root, num );
            }
        }
        else if ( op[0] == 'K' )
        {
            // k is 1-based; k <= 0 would otherwise walk into a NULL child.
            if ( num < 1 || num > node_size( root ) )
            {
                printf("invalid\n");
            }
            else
            {
                printf("%d\n", kth( root, num ));
            }
        }
        else
        {
            // ranker handles an absent num directly, replacing the
            // insert-query-delete workaround.
            printf("%d\n", ranker( root, num, 0 ));
        }
    }
    destroy( root );
    return 0;
}