题目本身很有趣,但是被这个弱智一样的输出格式和数据范围卡了半天……这种屑出题人建议枪毙一下,qqqxx。
设 (a_i) 的二进制最高位为 (2^l),显然任意子集的异或和 (x) 的最高位都不会超过 (2^l)。
我们需要求的是 (E(x^k)), 将 (x) 二进制拆分后,可以得到 (Eleft((sum_{i=0}^l x_icdot 2^i)^k ight)),其中 (x_iin {0, 1}),表示 (x) 的第 (i) 位是否为 (1)。考虑 ((sum_{i=0}^l x_icdot 2^i)^k) 的组合意义,相当于从 (x) 的二进制表示中任意选出 (k) 个 (x_i=1) 的位置 (i),每种选法的权值为 (2^{sum i}),所有选法的权值和。
因此根据期望的线性性,可以枚举对应的 (k) 个位置,求出有多少个子集满足异或和中这些位置都是 (1)。只需将每个数的这 (k) 个位置提取出来,依次插入到线性基中去。假如最终可以找到至少一个子集满足条件(等价于无法将仅这 (k) 个位置全 (1) 的数插入线性基),则它的贡献为 (2^{sum i-|S|}),其中 (|S|) 表示线性基内元素个数。
如果每个枚举中暴力遍历 ([1, n]),显然会超时。但是注意到线性空间的任意一个子空间,它的基一定是原空间的基的子集。因此只需要保留原来的 (n) 个数中的一个线性基即可。
复杂度 (O(l^kcdot operatorname{poly}(l)cdot operatorname{poly}(k)))。
注意到题面里没有给出 (a) 的上界,但是我们可以根据 “答案不超过 (2^{63}-1)” 估算出,(a^k) 大约也在 (2^{64}) 范围内。
但是输出格式很卡,大概需要在统计答案的时候记录 (2^{sum i-|S|+k}) 的系数,然后暴力判断小数点后有多少位……
另外 (k=1) 时,(a_i) 可以达到 (2^{64}-1) 的级别,需要用无符号 (64) 位整数……
输出比题目本身难还行
#include <bits/stdc++.h>
#define R register
#define mp make_pair
#define ll unsigned long long
#define pii pair<int, int>
using namespace std;
const int mod = 998244353, N = 110000;
int n, k, lim, ans[100];
ll a[N], base[70];
vector<ll> bs;
map<ll, int> f;
inline int insrt(ll x) {
for (R int i = lim; ~i; --i) {
if (~x & (1ull << i)) continue;
if (!base[i]) return base[i] = x, 1;
x ^= base[i];
}
return 0;
}
inline int addMod(int a, int b) {
return (a += b) >= mod ? a - mod : a;
}
inline ll quickpow(ll base, ll pw) {
ll ret = 1;
while (pw) {
if (pw & 1) ret = ret * base % mod;
base = base * base % mod, pw >>= 1;
}
return ret;
}
template <class T>
inline void read(T &x) {
x = 0;
char ch = getchar(), w = 0;
while (!isdigit(ch)) w = (ch == '-'), ch = getchar();
while (isdigit(ch)) x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();
x = w ? -x : x;
return;
}
void dfs(int now, ll w, int sum) {
if (now == k) {
if (f.find(w) == f.end()) {
int num = 0;
memset(base, 0, sizeof (base));
for (R int i = 0, sz = bs.size(); i < sz; ++i)
num += insrt(bs[i] & w);
// for (auto &v : bs) num += insrt(v & w);
f[w] = insrt(w) ? 0 : num;
}
if (f[w]) ++ans[sum + k - f[w]];
return;
}
for (R int i = 0; i <= lim; ++i)
dfs(now + 1, w | (1ull << i), sum + i);
return;
}
int main() {
read(n), read(k);
for (R int i = 1; i <= n; ++i) {
read(a[i]);
while ((1ull << lim << 1) - 1 < a[i]) ++lim;
if (insrt(a[i])) bs.push_back(a[i]);
}
dfs(0, 0, 0);
ll s = 0, t = 0;
for (R int i = k; i <= (lim + 1) * k; ++i)
s += ans[i] * (1ull << (i - k));
for (R int i = 0; i < k; ++i)
t += ans[i] * (1 << i);
while (k && (~t & 1)) --k, t >>= 1;
printf("%.*Lf
", k, s + (long double) t / (1 << k));
return 0;
}