题意 : 给出 N 堆石子,每次可以选择一堆石子拿走任意颗石子,最后没有石子拿的人为败者。现在后手 Bob 可以在游戏开始前拿掉不超过 d 堆的整堆石子,现在问你有几种取走的组合使得 Bob 能保证他在游戏开始后是必胜的。
分析 :
在没有附加规则,即 Bob 可以先取走某些堆的情况下
就是个简单的 Nim 博弈模型,后手必胜当且仅当各个堆的石子的数目的异或和为 0
那么题目就变成了,问有多少种取走组合使得剩下的石子的异或和为 0
可以发现,可取走的石子的堆数 d 的上限不大,所以这个问题可以用 DP 解决
定义 dp[i][j][k] = 到第 i 堆石子为止,取走 j 堆石子,异或和为 k 的方案数有多少种
由于异或的自反性质,如果要从一个异或和集合中删除某个数,那么就相当于用这个数去异或这个集合的异或和
那么可以根据这个写出状态转移方程如下
dp[i][j][k] = dp[i-1][j][k] + dp[i-1][j-1][k^pile[i]]
意义为 当前的DP值可以从取了这堆石子就能将异或和变为 k 的状态转移而来
那么就要求从异或和 k 中删除 pile[i] ,即直接拿 k 去异或 pile[i] 即可
也因为由于有这个性质,设 pile[1]^pile[2]...^pile[n] 原所有石子的异或和为 aim
那么最后的答案就存在 dp[n][1~d][aim] 中,意义为 取出的石子的异或和为 aim 的话
那么相当于从还未被取走任何一堆石子的所有的异或和 aim 中取走 aim 那么剩下的异或和就为 0
所以答案在 dp[n][1~d][aim] 中,在写 DP 的时候注意模就行了
#include<bits/stdc++.h> #define LL long long #define ULL unsigned long long #define scs(i) scanf("%s", i) #define sci(i) scanf("%d", &i) #define scd(i) scanf("%lf", &i) #define scl(i) scanf("%lld", &i) #define scIl(i) scanf("%I64d", &i) #define scii(i, j) scanf("%d %d", &i, &j) #define scdd(i, j) scanf("%lf %lf", &i, &j) #define scll(i, j) scanf("%lld %lld", &i, &j) #define scIll(i, j) scanf("%I64d %I64d", &i, &j) #define sciii(i, j, k) scanf("%d %d %d", &i, &j, &k) #define scddd(i, j, k) scanf("%lf %lf %lf", &i, &j, &k) #define sclll(i, j, k) scanf("%lld %lld %lld", &i, &j, &k) #define scIlll(i, j, k) scanf("%I64d %I64d %I64d", &i, &j, &k) #define lson l, m, rt<<1 #define rson m+1, r, rt<<1|1 #define lowbit(i) (i & (-i)) #define mem(i, j) memset(i, j, sizeof(i)) #define fir first #define sec second #define ins(i) insert(i) #define pb(i) push_back(i) #define pii pair<int, int> #define mk(i, j) make_pair(i, j) #define pll pair<long long, long long> using namespace std; const int maxn = 1000 + 50; const int mod = 1e9 + 7; int dp[maxn][15][maxn], arr[maxn]; int main(void) { int nCase; sci(nCase); while(nCase--){ int n, d; scii(n, d); d = min(d, n); int aim = 0, mx = 0; for(int i=1; i<=n; i++){ sci(arr[i]); mx = max(arr[i], mx); aim ^= arr[i]; } mem(dp, 0); for(int i=1; i<=n; i++) dp[i][1][arr[i]]++; for(int i=0; i<=10; i++) if((1<<i) > mx){ mx = (1<<i); break; } for(int i=1; i<=n; i++) for(int j=1; j<=d; j++) for(int k=0; k<=mx; k++) dp[i][j][k] = (dp[i][j][k]%mod + (dp[i-1][j][k] + dp[i-1][j-1][k^arr[i]])%mod)%mod; int ans = (aim==0) ? 1 : 0; for(int i=1; i<=d; i++) ans = (ans + dp[n][i][aim])%mod; printf("%d ", ans); } return 0; }