题目
给出(n,m_1,m_2,...,m_n),求(x_1 xor x_2 xor ... xor x_n=k (0 leq x_i leq m_i))的解的数量。二进制位数小于(32)。
分(tu)析(cao)
这其实是一道非常不错的题目,又感觉像rng_58
的题了。囧
算法
我们现在只考虑前(i)位(二进制)的状态,第(i+1 sim 32)位假设都已经填好了。这时,如果一个数(x_j)没有限制,也就是它的前(i)位可以任意地填(0)或(1),这个时候,我们可以将除这个数外的数任意填,设任意填完后的异或值为(b),那么取(x_j=k xor b),我们就得到了一个解。
这告诉我们,如果我们从高位到低位计算,如果有任意一个数(x_j)脱离了(m)的限制,那么就不再需要往下计算了,答案就是:除(x_j)外的数任意填,并且满足(b)的第(i)位等于(k)的第(i)位(因为这时的(x_j)的第(i)位已经为了脱离限制而填了(0))的方案数。那么怎么计算这个方案数呢,由于在(i)位的时候,可能有很多数同时脱离限制,所以我们可以枚举最后一个脱离限制的是哪个数,然后在这个数前面的可以用dp算出来,后面的由于都没有脱离限制,可以直接算。
时间复杂度,如果将预处理写好了就是(O(32n))的,我图方便就直接乱搞了。
#include <cstdio>
#include <cmath>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long i64;
const int MOD = (int) 1e9 + 3;
const int MAXN = 53;
int n, k;
int A[MAXN];
i64 solve(int cur) {
static i64 f[MAXN][2];
memset(f, 0, sizeof f);
f[0][0] = 1;
i64 upper = (1ll << (cur + 1)) - 1;
for (int i = 0; i < n; i ++)
for (int j = 0; j < 2; j ++) {
if (f[i][j] == 0) continue;
if (A[i] >> cur & 1) {
i64 x = upper - (1 << cur) + 1;
f[i + 1][j] = ((i64) f[i][j] * x + f[i + 1][j]) % MOD;
x = (A[i] & upper) - (1 << cur) + 1;
f[i + 1][j ^ 1] = ((i64) f[i][j] * x + f[i + 1][j ^ 1]) % MOD;
}
else {
i64 x = (A[i] & upper) + 1;
f[i + 1][j] = ((i64) f[i][j] * x + f[i + 1][j]) % MOD;
}
}
i64 ret = 0;
for (int i = 0; i < n; i ++) {
int cnt = 0;
for (int j = i + 1; j < n; j ++) // 可以预处理
cnt += A[j] >> cur & 1;
i64 sum = 0;
if (A[i] >> cur & 1)
for (int j = 0; j < 2; j ++)
if ((k >> cur & 1) == ((j + cnt) & 1))
sum = (sum + f[i][j]) % MOD;
for (int j = i + 1; j < n; j ++) // 可以预处理
if (A[j] >> cur & 1)
sum = sum * ((A[j] & upper) - (1 << cur) + 1) % MOD;
else
sum = sum * ((A[j] & upper) + 1) % MOD;
ret = (ret + sum) % MOD;
}
return ret;
}
int main() {
while (true) {
scanf("%d%d", &n, &k);
if (n == 0 && k == 0) break;
for (int i = 0; i < n; i ++)
scanf("%d", A + i);
i64 ans = 0;
for (int i = 30; i >= 0; i --) {
ans = (ans + solve(i)) % MOD;
int x = 0;
for (int j = 0; j < n; j ++)
x ^= A[j] >> i & 1;
if (x != (k >> i & 1)) break;
if (i == 0) ans = (ans + 1) % MOD;
}
cout << ans << endl;
}
return 0;
}