给定整数m以及n个数字A1,A2,...An,将数列A中所有元素两两异或,共能得到n(n-1)/2个结果,请求出这些结果中大于m的有多少个。
一看题目,感觉是trie树,也没搞清楚逻辑,就开始码代码,这样的结果注定是失败的!正确的做法是在纸上画清楚,每一步应该怎么做,应该怎么考虑边界条件,怎么搜索,拿最简单的例子测试,然后自己再想一些边界例子测试,最后才是码代码,除非你对这题很熟,或者感觉是水题,直接写也可以。反正最后不会做,看了一下大神代码,半天才理解,都是套路,其实这个性质一时半会分析不出来吧!
分析:暴力是不行的,1e5的数据范围,暴力肯定超时,大方向肯定是用trie树进行压缩,然后是查找,关键的问题是,给定数字a,我们需要寻找b的格式,使得a^b>m,怎么从trie树中寻找b的个数,就是解决这道题目的关键。
1. 我想说明下数据范围,n,m,Ai都是[1,1e5]的,(1 << 17)>1e5,所以一个数至少要17位来存储,所以trie树的节点个数就是1e5*17,这个不理解的话,仔细查看一下trie树的资料吧。
2. 然后是a^b>m,现在我们知道a和m,要查找b的个数,首先a和m可以简单的表示成17位二进制01的形式,然后查找。查找的时候,以m为导向,我们尽量确保a^b以后,如果m相应位置为0,a^b相应位置为1的,肯定比m大(肯定大需要按照这里要求的方式从高位到低位进行枚举),然后最后结果加上这些为1的即可!不管后面的位是什么情况,因为结果肯定是大于m的! 如果m相应的位置为1,我们需要a和b相应的位置不同,即一个为1,一个为0.到这里可能迷糊了,如果相同也能保证异或结果大于m啊,比如(m=01010,a=00110,这里简单起见,一共5位,从左到右编号,m的第一位为0,如果a^b以后这位为1,显然是满足的,最后结果直接加上即可,不管后面的位是什么情况;然后考虑m的第二位为1,这里我需要保证a^b结果,相应位为1,然后往后才能找到满足要求异或结果大于m的,当然,有人说结果是0也可以啊,但是前一种情况已经把这种情况考虑进去了,这里只需要考虑这里的结果位为1即可,也就是a和b相应的二进制的这一位不同)
3. 第二点有点难懂,下面可以结合代码理解。我这里还要说明,什么时候统计结果,统计结果就是a和b的异或结果大于m,也就是m相应位置为1,结果就是异或结果相应位置为1,也就是异或结果尽量保证m二进制位上为1的位置尽量为1,(尽量为1的意思是:不一定一定为1,考虑上面的给的例子的第一种情况)。然后为0的位置至少有一个位置为1.最后的结果要用long long来存储,n(n-1)/2,int可能会溢出。最后,(a,b)和(b,a)算一种情况,所以结果需要除以2.
友情提示:跟1异或,相当于该位取反。a^b=c,有a^c=b.
下面是牛客网上抄的,比我讲的清楚。
异或那道题可以把每个数的二进制位求出来,用一个字典树维护,然后遍历每一个数按位贪心,比如这一位m是1,遍历的这个数这一位是0,那么和他异或的数就必须是1,如果这一位m是0,要大于m的话异或和的这一位可以是1也可以是零,ans加上之前维护的二进制位加上使这一位为1的数在字典树中查询有多少个数满足这个前缀的条件,然后在令这一位的异或和为0,继续向下遍历,最后的答案除以2.
贴上大神的代码,膜拜一下!orz。
放几个我认为有帮助的题目吧:
1. http://codeforces.com/problemset/problem/282/E 这个相关度最高,trie+xor
2. http://codeforces.com/contest/706/problem/D 这个也是trie+xor
3. http://codeforces.com/contest/714/problem/C 这个只有trie,这个题有点意思,还有不用trie树的简单做法。
4. https://threads-iiith.quora.com/Tutorial-on-Trie-and-example-problems 关于trie树的一点知识点吧!
1 #include <cstdio> 2 #include <cstring> 3 4 const int N = 100010; 5 6 int a[N]; 7 8 struct node { 9 int count; 10 int next[2]; 11 }p[N*17], root; 12 13 int cnt = 0; 14 void insert(int *a, int len) { 15 int now = 0; 16 for (int i = 0; i < len; ++i) { 17 if (p[now].next[a[i]] == -1) { 18 cnt++; 19 p[cnt].next[0] = p[cnt].next[1] = -1; 20 p[cnt].count = 0; 21 p[now].next[a[i]] = cnt; 22 } 23 now = p[now].next[a[i]]; 24 p[now].count++; 25 } 26 } 27 28 typedef long long LL; 29 int query(int *a, int *b, int len) { 30 int now = 0; 31 int ret = 0; 32 for (int i = 0; now != -1 && i < len; ++i) { 33 if (b[i] == 0) { 34 if (p[now].next[1^a[i]] != -1) ret += p[p[now].next[1^a[i]]].count; 35 now = p[now].next[a[i]]; 36 } 37 else { 38 now = p[now].next[1^a[i]]; 39 } 40 } 41 return ret; 42 } 43 44 45 int main() { 46 int n, m; 47 while (scanf("%d%d", &n, &m) == 2) { 48 cnt = 0; 49 p[0].next[0] = p[0].next[1] = -1; 50 p[0].count = 0; 51 for (int i = 0; i < n; ++i) { 52 scanf("%d", &a[i]); 53 int tmp[18]; 54 for (int j = 0; j < 18; ++j) 55 tmp[j] = (a[i] >> (17 - j)) & 1; 56 insert(tmp, 18); 57 //for (int i = 0; i < 30; ++i) printf("%d ", p[i].count); 58 //puts("----"); 59 } 60 int kk[18]; 61 for (int j = 0; j < 18; ++j) 62 kk[j] = (m >> (17 - j)) & 1; 63 LL ret = 0; 64 for (int i = 0; i < n; ++i) { 65 int tmp[18]; 66 for (int j = 0; j < 18; ++j) 67 tmp[j] = (a[i] >> (17 - j)) & 1; 68 ret += query(tmp, kk, 18); 69 //printf("%d ", ret); 70 } 71 printf("%lld ", ret / 2); 72 } 73 return 0; 74 }
之前不小心加上了gist的链接,所以打开很慢!
摘自牛客网:
刚刚听到另外一个方法...建好树之后,把m转成二进制,如果m当前位是0,直接把经过左右节点的数的个数相乘;如果m当前位是1,就分别从左右节点的子节点里选一个分支进行组合,递归调用,结果相加......
dfs里面的第一个if条件的原因,如果左右节点相等,i>j的情况不考虑,原因是:对于每一位,我们考虑(00,01,10,11),当当前节点相同的时候,01和10只需要计算一次即可,仔细想想!
1 /* 2 ID: y1197771 3 PROG: test 4 LANG: C++ 5 */ 6 #include<bits/stdc++.h> 7 #define pb push_back 8 #define FOR(i, n) for (int i = 0; i < (int)n; ++i) 9 #define dbg(x) cout << #x << " at line " << __LINE__ << " is: " << x << endl 10 typedef long long ll; 11 using namespace std; 12 typedef pair<int, int> pii; 13 const int maxn = 1e5 + 10; 14 struct node { 15 int next[2]; 16 int c; 17 node() { 18 memset(next, 0, sizeof next); 19 c = 0; 20 } 21 } A[maxn * 18]; 22 int num; 23 int n, m; 24 int tag[18]; 25 void insert(int x) { 26 int u = 0, cur = 0; 27 for (int i = 17; i >= 0; i--) { 28 cur = ((1 << i) & x) > 0; 29 if(!A[u].next[cur]) { 30 A[u].next[cur] = ++num; 31 } 32 u = A[u].next[cur]; 33 //cout << u << " " << x << endl; 34 A[u].c++; 35 } 36 } 37 ll dfs(int cur, int l, int r) { 38 39 if(cur < 0) return 0; 40 ll res = 0; 41 for (int i = 0; i <= 1; i++) { 42 for (int j = 0; j <= 1; j++) { 43 if(l == r && i > j) continue; 44 if(!A[A[l].next[i] ].c || !A[A[r].next[j] ].c) continue; 45 if((i ^ j) > tag[cur]) { 46 res += 1ll * A[A[l].next[i] ].c * A[A[r].next[j] ].c; 47 } else if((i ^ j) == tag[cur]) { 48 res += dfs(cur - 1, A[l].next[i], A[r].next[j]); 49 } 50 } 51 } 52 //cout << cur << " " << l << " " << r << " " << res << endl; 53 return res; 54 } 55 void solve() { 56 num = 0; 57 memset(A, 0, sizeof A); 58 cin >> n >> m; 59 int x; 60 for (int i = 0; i < n; i++) { 61 cin >> x; 62 insert(x); 63 } 64 for (int i = 17; i >= 0; i--) { 65 tag[i] = ((1 << i) & m) > 0; 66 } 67 printf("%I64d ", dfs(17, 0, 0)); 68 } 69 int main() { 70 freopen("data", "r", stdin); 71 //freopen("test.out", "w", stdout); 72 int _ = 20; 73 while(_--) 74 solve(); 75 return 0; 76 }