源:CNUOJ-0384 http://oj.cnuschool.org.cn/oj/home/problem.htm?problemID=354
题目分析:当时拿到这道题第一个想法就是排序后n^2暴力枚举,充分利用好有序这一特性随时“短路”。
1 read(n);read(m); 2 int temp; 3 4 for(int i = 0; i < n; i++) 5 { 6 read(temp); 7 if(temp <= m) 8 { 9 p[tot].v = temp; 10 p[tot++].num = i; 11 } 12 } 13 14 sort(p, p + tot); 15 16 for(int i = 0; i < tot - 1; i++) 17 { 18 for(int j = i + 1; j < tot; j++) 19 { 20 if(p[i].v + p[j].v > m) break; 21 else if(p[i].num < p[j].num) ans++; 22 } 23 ans %= 1000000007; 24 } 25 26 printf("%d ", ans); 27 28 return 0;
然后很高兴的简单一测试就哭了……为什么答案始终不对呢……
如上图就是一个十分典型的例子(下标为序号),不难发现当我们手算时,我们是算的“2 + 1”,但排完序后计算机执行的是“1 + 2”,序号不符合题意就被砍掉了……
那么怎么解决呢?第一个想法便是优化一下 j 的循环,从头到尾全来一遍,保证不漏解:
1 read(n);read(m); 2 int temp; 3 4 for(int i = 0; i < n; i++) 5 { 6 read(temp); 7 if(temp <= m) 8 { 9 p[tot].v = temp; 10 p[tot++].num = i; 11 } 12 } 13 14 sort(p, p + tot); 15 16 for(int i = 0; i < tot; i++) 17 { 18 for(int j = 0; j < tot; j++) 19 { 20 if(i == j) continue; 21 if(p[i].v + p[j].v > m) break; 22 else if(p[i].num < p[j].num) ans++; 23 } 24 ans %= 1000000007; 25 } 26 27 printf("%d ", ans);
排序是nlogn,枚举是n^2,对于十万的数据量无法吐槽……只能有50分……
好的,接下来我们来分析一下这个算法干了些什么,比如对于3 2 1这个序列:
可以看到,1 -> 2, 2 -> 1 分别别计算过一次,其中红色的一次失败了,蓝色的成功了。(在红色中3 > 2,而在蓝色中是2 < 3满足题意)
也不难得出:每一对数字都被正着计算了一次,反着计算了一次;一次失败,一次成功。
那么:反正判断不判断都是ans++ 一次(仅成功一次,失败的时候不会更新),那为什么要判断呢?
于是,我们编出了下面没有判断编号只判断了大小的程序:
1 for(int i = 0; i < tot - 1; i++) 2 { 3 for(int j = i + 1; j < tot; j++) 4 { 5 if(p[i].v + p[j].v > m) break; 6 ans++; 7 } 8 ans %= 1000000007; 9 }
这个算法常数可以几乎砍掉一半多,但复杂度仍然是n^2的,它的表现……呵呵
进一步的优化就要从查询方法开始了:用十分常见的二分法效果怎么样呢?于是我们终于得到了标程:
1 #include <iostream> //第一题标程 2 #include <cstdio> 3 #include <cmath> 4 #include <algorithm> 5 using namespace std; 6 7 const int maxn = 100000 + 5; 8 int n, m; 9 10 int p[maxn]; 11 int ans = 0; 12 13 int tot = 0; 14 15 int A[maxn]; 16 17 void read(int& x) 18 { 19 x = 0; 20 char ch = getchar(); 21 int sig = 1; 22 while(!isdigit(ch)) 23 { 24 if(ch == '-') sig = -1; 25 ch = getchar(); 26 } 27 while(isdigit(ch)) 28 { 29 x = x * 10 + ch - '0'; 30 ch = getchar(); 31 } 32 x *= sig; 33 return ; 34 } 35 36 37 int main() 38 { 39 read(n);read(m); 40 int temp; 41 42 for(int i = 0; i < n; i++) 43 { 44 read(temp); 45 if(temp <= m) p[tot++] = temp; 46 } 47 48 sort(p, p + tot); 49 50 int L, R, M; 51 52 for(int i = 0; i < tot - 1; i++) 53 { 54 L = i; 55 R = tot - 1; 56 57 if(p[R] + p[i] <= m) 58 { 59 ans += R - i; 60 continue; 61 } 62 63 while(L < R) 64 { 65 M = L + R >> 1; 66 //printf("L:%d M:%d R:%d ", L, M, R); 67 if(p[i] + p[M] <= m) L = M + 1; 68 else R = M; 69 } 70 71 if(L - i > 1) ans += L - i - 1; 72 73 //cout << ans << endl; 74 ans %= 1000000007; 75 } 76 77 printf("%d ", ans); 78 79 return 0; 80 }
从n^2到nlogn是不小的进步,十万的数据专为nlogn而生:
这道题就这样被干掉了……
当然了,其实还可以用treap瞎搞啦:
1 #include <cstdio> 2 #include <iostream> 3 #include <algorithm> 4 #include <ctime> 5 using namespace std; 6 const int maxn = 100000 + 10; 7 struct Node{ 8 int r, v, s; 9 Node* ch[2]; 10 Node() {} 11 void maintain(){ 12 s = ch[0] -> s + ch[1] -> s + 1; 13 return ; 14 } 15 }nodes[maxn], *null = &nodes[0]; 16 int tot = 0, n, m, A[maxn]; 17 void read(int& x){ 18 x = 0; int sig = 1; char ch = getchar(); 19 while(!isdigit(ch)) { if(ch == '-') sig = -1; ch = getchar(); } 20 while(isdigit(ch)) x = 10 * x + ch - '0', ch = getchar(); 21 x *= sig; return ; 22 } 23 void rotate(Node* &o, int d){ 24 Node* k = o -> ch[d ^ 1]; 25 o -> ch[d ^ 1] = k -> ch[d]; 26 k -> ch[d] = o; 27 o -> maintain(); 28 k -> maintain(); 29 o = k; 30 return ; 31 } 32 void init(Node* &o, int v){ 33 o -> ch[0] = null; 34 o -> ch[1] = null; 35 o -> s = 1; 36 o -> r = rand(); 37 o -> v = v; 38 return ; 39 } 40 void insert(Node* &o, int v){ 41 if(o == null){ 42 o = &nodes[++ tot]; 43 init(o, v); 44 } 45 else{ 46 int d = v < o -> v ? 0 : 1; 47 insert(o -> ch[d], v); //你大爷 48 if(o -> ch[d] -> r > o -> r) rotate(o, d ^ 1); 49 } 50 o -> maintain(); 51 return ; 52 } 53 void remove(Node* &o, int v){ 54 if(o == null) return ; 55 if(o -> v == v){ 56 if(o -> ch[0] == null) o = o -> ch[1]; 57 else if(o -> ch[1] == null) o = o -> ch[0]; 58 else{ 59 int d = o -> ch[0] -> r < o -> ch[1] -> r ? 0 : 1; 60 rotate(o, d); 61 remove(o -> ch[d ^ 1], v); 62 } 63 } 64 if(o != null) o -> maintain(); 65 return ; 66 } 67 /*int find(Node* &o, int v){ 68 if(o == null) return -1; 69 if(o -> v == v){ 70 if(o -> ch[0] != null) return o -> ch[0] -> s; 71 else return 1; 72 } 73 int d = v < o -> v ? 0 : 1; 74 return find(o -> ch[d], v); 75 }//*/ 76 int rank(Node* &o, int v){ 77 if(o == null) return 0; 78 if(o -> v > v) return rank(o -> ch[0], v); 79 else return rank(o -> ch[1], v) + o -> ch[0] -> s + 1; 80 } 81 Node* root = null; 82 long long ans = 0; 83 void Init(){ 84 srand(time(0)); 85 null -> s = 0; 86 read(n); read(m); 87 for(int i = 0; i < n; i ++){ 88 read(A[i]); 89 ans += rank(root, m - A[i]); 90 //printf("%d ", rank(root, m - A[i])); 91 insert(root, A[i]); 92 } 93 return ; 94 } 95 96 void print(Node* &o){ 97 if(o == null) return ; 98 print(o -> ch[0]); 99 printf("%d ", o -> v); 100 print(o -> ch[1]); 101 return ; 102 } 103 104 int main(){ 105 Init(); 106 printf("%lld ", ans % 1000000007); 107 return 0; 108 } 109 /* 110 5 100000 111 1 2 3 4 5 112 */
小结:本题一定要好好读条件,仔细分析哪些可以优化。
考点:二分,坑题