题意
给定(n)个带权点,第(i)个点的权值为(w_i),任意两点间都有边,边权为两端点权的异或值,求最小生成树边权和,以及方案数(mod 10^9 + 7)
(n leq 10^5,W = max(w_i) leq 2^{30})
题解
考虑按位贪心,我们从高到低考虑二进制第k位。每次把当前点集(S)分成第(k)位为(0)和第(k)位为(1)的两个集合,记为(S_0, S_1)。
我们递归下去把这两个集合连成生成树,然后再找一条最小的跨集合的边把这两个集合连通。
考虑这么做为啥对:假设有两条跨集合的边,我删去一条,树变成两个部分。然后任意找到一条集合内部边使集合(S)连通(既然有跨集合的边存在,我们一定能找到这样的一条边),这样显然更优。
然后考虑问题:找到(xin S_0,yin S_1,x ext{ xor } y)最小。
这个用类似线段树合并的方法:每次两个结点同时往下走,尽量往一边走。如果能同时往(0/1)走,都走一遍,复杂度是对的,每次合并复杂度是子树大小。考虑trie树上一个点只有(O(log W))个祖先,一共只有(O(n log W))个结点,所以复杂度(O(n log ^2 W))
我们再来考虑方案。叶子结点时假设大小为(n),也就是说(n)个点都是这个权值,生成树的方案数(n^{n-2})(由prufer序列得)。非叶子结点时,方案是分成的两个集合的方案乘最后连边方案。连边会对应trie树上多对叶子((u, v))(这些对结点异或起来都是最小的),若叶子(u)上放的数个数用(cnt[u])表示,连边方案就是(sum_{(u,v)} cnt[u]*cnt[v])。
P.S.:快速幂写错了调了好久,差评
#include <algorithm>
#include <cstdio>
using namespace std;
typedef long long ll;
char gc() {
static char buf[1 << 20], * S, * T;
if(S == T) {
T = (S = buf) + fread(buf, 1, 1 << 20, stdin);
if(S == T) return EOF;
}
return *S ++;
}
template<typename T> void read(T &x) {
x = 0; char c = gc(); bool bo = 0;
for(; c > '9' || c < '0'; c = gc()) bo |= c == '-';
for(; c <= '9' && c >= '0'; c = gc()) x = x * 10 + (c & 15);
if(bo) x = -x;
}
const int N = 1e5 + 10;
const int mo = 1e9 + 7;
int n, id = 1, ch[N * 30][2], cnt[N * 30], w[N * 30];
void insert(int x) {
int u = 1;
for(int i = 29; ~ i; i --) {
int y = x >> i & 1;
if(!ch[u][y]) {
ch[u][y] = ++ id;
w[id] = y << i;
}
u = ch[u][y];
}
cnt[u] ++;
}
ll ans;
int ans2, tot2, tot = 1;
int qpow(int a, int b) {
int ans = 1;
for(; b >= 1; b >>= 1, a = (ll) a * a % mo)
if(b & 1) ans = (ll) ans * a % mo;
return ans;
}
void merge(int u, int v, int now) {
now ^= w[u] ^ w[v];
if(cnt[u] && cnt[v]) {
if(now < ans2) { ans2 = now; tot2 = 0; }
if(now == ans2) tot2 = (tot2 + (ll) cnt[u] * cnt[v]) % mo;
return ;
}
bool tag = 0;
if(ch[u][0] && ch[v][0]) merge(ch[u][0], ch[v][0], now), tag = 1;
if(ch[u][1] && ch[v][1]) merge(ch[u][1], ch[v][1], now), tag = 1;
if(tag) return ;
if(ch[u][0] && ch[v][1]) merge(ch[u][0], ch[v][1], now);
if(ch[u][1] && ch[v][0]) merge(ch[u][1], ch[v][0], now);
}
bool solve(int u) {
if(!u) return 0;
if(cnt[u]) {
if(cnt[u] > 2) tot = (ll) tot * qpow(cnt[u], cnt[u] - 2) % mo;
return 1;
}
bool s = solve(ch[u][1]) & solve(ch[u][0]);
if(s) {
ans2 = 2e9 + 10; tot2 = 1;
merge(ch[u][0], ch[u][1], 0);
ans += ans2; tot = (ll) tot * tot2 % mo;
}
return 1;
}
int main() {
read(n);
for(int i = 1; i <= n; i ++) {
int x; read(x); insert(x);
}
solve(1);
printf("%lld
%d
", ans, tot);
return 0;
}