前言
很套路的一道题。
题目
链接就算了吧。
题目大意:给 (n) 个数字 (a_i),将两两数字的异或和进行排序,询问前 (m) 大的和,对 (10^9+7) 取模。
((a_i,a_j)) 和 ((a_j,a_i)) 算一种组合方式。
(nle 5 imes 10^4,0le a_i,mle 10^9.)
讲解
为了方便讲解,我们将两个数的异或和称作异或数,(a_i) 称为原数。
天马行空的想象
首先我们看到异或,我们就可以想到字典树。
然后问前 (m) 大,最原始的想法当然是把所有异或数都塞到优先队列里面去,然后取 (m) 个。
然后我想到了异或粽子,发现其实并不好转移到次大值,而且 (m) 没给,我们默认是 (n^2) 级别(后来发现确实是)。
正解
思路
容易想到二分。
直接二分第 (m) 大的数的大小 (val),每次二分只需要统计有多少个数大于等于 $ val$,记为 (c),最后计算这些数的和即可。由于 (c) 可能大于 (m),减去最小的异或数即可,容易发现那些异或数都是 (val)。
时间复杂度 (O(nlog_2^2(a_i)))。
实现
part1 二分异或数的个数
我们考虑将所有原数依次插入字典树,插入时经过的字典树上的每个点都用 (cnt) 记录一下,就可以知道从这个点往后走有多少个原数了。
每插入一个原数就统计与当前原数异或起来大于等于 (val) 的异或数的个数,我们可以直接在字典树上 (dfs)。
具体的,设之前走过的点累计起来为 (s),我们走向当前位置 (i) 异或和为 (1) 的点,如果走过去权值即为 (s+2^i)。
-
如果走过去后已经大于等于 (val),后面不管怎么走都一定大于等于 (val),所以直接累加这个点的 (cnt) 即可,当然另外一边要接着走下去,总共只需要走一边。
-
否则如果走了这边 (s+2^i) 依然小于 (val),另外一边得到的最大权值也只能是 (s+sum_{j=0}^{i-1}2^j=s+2^i-1),不可能达到 (val),所以另外一边不需要走,也只需要走一边。
故时间复杂度得到保证,为 (O(log_2(a_i)))。算上外面的二分和里面的建树,这个部分是 (O(nlog_2^2(a_i))) 的。
为了卡常,我们可以只建一次树,当然这样计算 ((a_i,a_j)) 与 ((a_j,a_i)) 会被记录两次,所以最后需要除以二。
part2 求和
与 ( t part1) 的思想类似,只不过建树的时候我们多维护一个信息:到这个点的数每一位上的 (1) 有多少个,我们记为 (yi_i(iin[0,29]))。
在 ( t part1) 中累加 (cnt) 的地方我们用 (O(log_2(a_i))) 的时间算一下贡献即可。
这个部分没有二分,但是多了 (O(log_2(a_i))) 算贡献,所以时间复杂度还是 (O(nlog_2^2(a_i)))。
值得注意的是如果你二分的时候就求贡献,时间复杂度就会达到 (O(nlog_2^3(a_i))),是无法通过的。
所以你可以像我一样相似的函数写两份。
代码
我为试图阅读我代码的勇者附了注释,希望对你有所帮助。
丑陋的考场代码
//12252024832524
#include <cstdio>
#include <cstring>
#include <algorithm>
#define TT template<typename T>
using namespace std;
typedef long long LL;
const int MAXN = 50005;
const int MOD = 1e9 + 7;
int n,m,cc,ans;
int a[MAXN];
LL Read()
{
LL x = 0,f = 1;char c = getchar();
while(c > '9' || c < '0'){if(c == '-')f = -1;c = getchar();}
while(c >= '0' && c <= '9'){x = (x*10) + (c^48);c = getchar();}
return x * f;
}
TT void Put1(T x)
{
if(x > 9) Put1(x/10);
putchar(x%10^48);
}
TT void Put(T x,char c = -1)
{
if(x < 0) putchar('-'),x = -x;
Put1(x); if(c >= 0) putchar(c);
}
TT T Max(T x,T y){return x > y ? x : y;}
TT T Min(T x,T y){return x < y ? x : y;}
TT T Abs(T x){return x < 0 ? -x : x;}
int tot,cao;
struct node
{
int ch[2],cnt,yi[30];
}t[MAXN * 31];
int dfs1(int now,int x,int val,int s)
{
if(x < 0) return 0;
int to = val >> x & 1;
int fan = to ^ 1,ret = 0;
if(t[now].ch[fan]) //走这边是 s+2^i
{
if(s + (1 << x) >= cao) ret += t[t[now].ch[fan]].cnt;//如果大于等于val,可以直接加(代码中为cao)
else ret += dfs1(t[now].ch[fan],x-1,val,s+(1<<x));//否则只用走这边,另外一边达不到 (*1)
}
if(t[now].ch[to] && s + (1<<x)-1 >= cao) ret += dfs1(t[now].ch[to],x-1,val,s);//显然如果这里还有机会,那么 *1 不会执行。
return ret;
}
int check()
{
int c = 0;
for(int i = 1;i <= n;++ i) c += dfs1(0,29,a[i],0);
return c / 2;
}
//---------------------------违和的分割线-----------------------------
void jia(int x,int now)//用 O(log_2(a_i)) 的时间算对答案的贡献
{
for(int i = 29;i >= 0;-- i)
{
int fan = (x >> i & 1) ^ 1;
if(fan) ans = (ans + (1ll << i) * t[now].yi[i]) % MOD;
else ans = (ans + (1ll << i) * (t[now].cnt - t[now].yi[i])) % MOD;
}
}
int dfs2(int now,int x,int val,int s)//与上面类似
{
if(x < 0) return 0;
int to = val >> x & 1;
int fan = to ^ 1,ret = 0;
if(t[now].ch[fan])
{
if(s+(1<<x) >= cao)
{
jia(val,t[now].ch[fan]);
ret += t[t[now].ch[fan]].cnt;
}
else ret += dfs2(t[now].ch[fan],x-1,val,s+(1<<x));
}
if(t[now].ch[to] && s + (1<<x)-1 >= cao) ret += dfs2(t[now].ch[to],x-1,val,s);
return ret;
}
int dz[MAXN];
void solve()
{
for(int i = 0;i <= tot;++ i) t[i].ch[0] = t[i].ch[1] = t[i].cnt = 0; tot = 0;
for(int i = 1;i <= n;++ i)
{
int now = 0,len = 0;;
for(int k = 29;k >= 0;-- k)//剪枝,卡常卡常卡常
if(a[i] >> k & 1)
dz[++len] = k;
for(int j = 29;j >= 0;-- j)
{
int to = a[i] >> j & 1;
if(!t[now].ch[to]) t[now].ch[to] = ++tot;
now = t[now].ch[to];
t[now].cnt++;
for(int k = 1;k <= len;++ k)
t[now].yi[dz[k]]++;
}
cc += dfs2(0,29,a[i],0);
}
return;
}
int main()
{
freopen("xor.in","r",stdin);
freopen("xor.out","w",stdout);
n = Read(); m = Read();
for(int i = 1;i <= n;++ i) a[i] = Read();
//二分前建树卡常
for(int i = 1;i <= n;++ i)
{
int now = 0;
for(int j = 29;j >= 0;-- j)
{
int to = a[i] >> j & 1;
if(!t[now].ch[to]) t[now].ch[to] = ++tot;
now = t[now].ch[to];
t[now].cnt++;
}
}
int l = 1,r = (1 << 30) - 1,ret = 1;
while(l <= r)
{
int mid = (l+r) >> 1; cao = mid;
if(check() >= m) l = mid+1,ret = mid;
else r = mid-1;
}
cao = ret;
solve();
//注意减掉多余的部分
Put(((ans - 1ll * Max(0,cc-m) * ret) % MOD + MOD) % MOD);
return 0;
}
太懒了,所以没图。
我才不会告诉你这是我精简过后的代码。