题意:
给你n个数,让你找有多少个(i,j,k),使得i<j<k满足a[i]^a[j]<a[j]^a[k]。
题解:
首先考虑a[i]和a[k],将他们都转换成二进制,对于a[i]和a[k],我们用Bi[p]表示二进制下的a[i]的第p位。考虑a[i]和a[k]二进制不同的最高位,这里假设为p,如果Bi[p]=0,Bk[p]=1,那么Bj[p]要为0,才能使得a[i]^a[j]<a[j]^a[k]。(因为p前面的位相同,只有亦或后的第p位k为1,i为0就行了)比如a[i]=5,a[k]=6,那么二进制下不同的最高为就是2,那么a[j]可以为0,1,5,13等,只要Bj[2]=0就行了。同理如果Bi[p]=1,Bk[p]=0,那么Bj[p]要为1。
知道这个性质后,现在我们就可以对每一个数进行枚举了。这里我们从前往后枚举每一个a[k],将1~(k-1)的数都全部插到trie树里面,并且开一个数组num[30][2]记录第p位为0和为1的数有多少个。
然后在插入过程中,在trie树里面找有多少个i和j,j的个数就是num[p][Bk[p]^1]],i的个数就是当前cnt[x][Bk[p]^1]](因为这样计算保证了a[i]和a[k]的二进制前缀相同,就该位不同),但是cnt[x][Bk[p]^1]]中又包括有j的个数,所以这里计算要注意一下,具体看代码。
这里有一部分i,j没有保证i<j,因为a[j]可能是在a[i]前面的数,所以在对于每一个新插入的数,都要保存一下这个数被当成a[i]时,有多少个a[j]在他前面,就是num[p][Bk[p]]]-cnt[x][Bk[p]],然后后面的数在计算时就要减掉。(具体的话举几个数模拟一下就知道了)
1 #include<bits/stdc++.h> 2 #define F(i,a,b) for(int i=a;i<=b;++i) 3 using namespace std; 4 5 const int M=5e5+7; 6 int t,n,a[M],num[31][2],s[40]; 7 long long ans; 8 9 struct Trie 10 { 11 static const int N=5e6+7,tyn=2; 12 int tr[N][tyn],cnt[N],tot;long long ext[N]; 13 void nw(){cnt[++tot]=0,ext[tot]=0,memset(tr[tot],0,sizeof(tr[tot]));} 14 void init(){tot=-1,nw();} 15 void insert(int *s,int x=0){ 16 for(int i=0,w;i<30;i++) 17 { 18 if(!tr[x][w=s[i]])nw(),tr[x][w]=tot; 19 if(tr[x][w^1]) 20 { 21 int nxt=tr[x][w^1]; 22 ans+=1ll*cnt[nxt]*(cnt[nxt]-1)>>1; 23 ans+=1ll*(num[i][w^1]-cnt[nxt])*cnt[nxt]-ext[nxt]; 24 } 25 x=tr[x][w],cnt[x]++,ext[x]+=num[i][w]-cnt[x]; 26 } 27 } 28 }trie; 29 30 int main() 31 { 32 scanf("%d",&t); 33 while(t--) 34 { 35 scanf("%d",&n); 36 F(i,1,n)scanf("%d",a+i); 37 trie.init(),ans=0; 38 F(i,1,n) 39 { 40 for(int j=29;j>=0;a[i]>>=1,j--) 41 { 42 s[j]=a[i]&1; 43 num[j][a[i]&1]++; 44 } 45 trie.insert(s); 46 } 47 printf("%lld ",ans); 48 } 49 return 0; 50 }