Description
有一个 n 行 m 列的表格,行从 0 到 n−1 编号,列从 0 到 m−1 编号。每个格子都储存着能量。最初,第 i 行第 j 列的格子储存着 (i xor j) 点能量。所以,整个表格储存的总能量是,
随着时间的推移,格子中的能量会渐渐减少。一个时间单位,每个格子中的能量都会减少 1。显然,一个格子的能量减少到 0 之后就不会再减少了。
也就是说,k 个时间单位后,整个表格储存的总能量是,
给出一个表格,求 k 个时间单位后它储存的总能量。
由于总能量可能较大,输出时对 p 取模。
Input
第一行一个整数 T,表示数据组数。接下来 T 行,每行四个整数 n、m、k、p。
Output
共 T 行,每行一个数,表示总能量对 p 取模后的结果
Sample Input
3
2 2 0 100
3 3 0 100
3 3 1 100
2 2 0 100
3 3 0 100
3 3 1 100
Sample Output
2
12
6
12
6
HINT
T=5000,n≤10^18,m≤10^18,k≤10^18,p≤10^9
好恶心的数位DP,先将三个串二进制拆分,然后设g[len][S]表示前len位状态为S的方案数,f[len][S]表示前len位状态为S的(i xor j)-k的结果。
S包括i与n的大小关系,j与m的大小关系,i xor j与k的大小关系,然后使劲讨论就行了。
#include<cstdio> #include<cctype> #include<queue> #include<cstring> #include<algorithm> #define rep(i,s,t) for(int i=s;i<=t;i++) #define dwn(i,s,t) for(int i=s;i>=t;i--) #define ren for(int i=first[x];i;i=next[i]) using namespace std; const int BufferSize=1<<16; char buffer[BufferSize],*head,*tail; inline char Getchar() { if(head==tail) { int l=fread(buffer,1,BufferSize,stdin); tail=(head=buffer)+l; } return *head++; } typedef long long ll; inline ll read() { ll x=0,f=1;char c=getchar(); for(;!isdigit(c);c=getchar()) if(c=='-') f=-1; for(;isdigit(c);c=getchar()) x=x*10+c-'0'; return x*f; } ll xp[70],f[70][2][2][2],g[70][2][2][2],n,m,k; //c1=1 -> x<n c2=1 -> y<m c3=1 -> x^y>k int p,bitn[70],lenn,bitm[70],lenm,bitk[70],lenk; void solve() { memset(bitn,0,sizeof(bitn)); memset(bitm,0,sizeof(bitm)); memset(bitk,0,sizeof(bitk)); memset(f,0,sizeof(f)); memset(g,0,sizeof(g)); lenn=lenm=lenk=0; while(n) bitn[lenn++]=n&1,n>>=1; while(m) bitm[lenm++]=m&1,m>>=1; while(k) bitk[lenk++]=k&1,k>>=1; int len=max(lenn,max(lenm,lenk)); xp[0]=1;rep(i,1,len) xp[i]=(xp[i-1]*2)%p; rep(c1,0,1) rep(c2,0,1) rep(c3,0,1) { ll &ans=f[0][c1][c2][c3],&ans2=g[0][c1][c2][c3]; rep(x,0,(c1?1:bitn[0]-1)) rep(y,0,(c2?1:bitm[0]-1)) { if((x^y)>=bitk[0]) (ans+=((x^y)-bitk[0]))%=p,ans2++; else if(c3) (ans+=((x^y)-bitk[0]))%=p,ans2++; } } rep(i,1,len-1) rep(c1,0,1) rep(c2,0,1) rep(c3,0,1) { ll &ans=f[i][c1][c2][c3],&ans2=g[i][c1][c2][c3]; rep(x,0,max(bitn[i],c1)) rep(y,0,max(bitm[i],c2)) { if((x^y)>=bitk[i]) { (ans+=f[i-1][c1|(x<bitn[i])][c2|(y<bitm[i])][c3|((x^y)>bitk[i])]+g[i-1][c1|(x<bitn[i])][c2|(y<bitm[i])][c3|((x^y)>bitk[i])]*((x^y)-bitk[i])*xp[i])%=p; (ans2+=g[i-1][c1|(x<bitn[i])][c2|(y<bitm[i])][c3|((x^y)>bitk[i])])%=p; } else if(c3) { (ans+=f[i-1][c1|(x<bitn[i])][c2|(y<bitm[i])][c3|((x^y)>bitk[i])]+g[i-1][c1|(x<bitn[i])][c2|(y<bitm[i])][c3|((x^y)>bitk[i])]*((x^y)-bitk[i])*xp[i])%=p; (ans2+=g[i-1][c1|(x<bitn[i])][c2|(y<bitm[i])][c3|((x^y)>bitk[i])])%=p; } } } printf("%lld ",(f[len-1][0][0][0]+p)%p); } int main() { dwn(T,read(),1) { n=read();m=read();k=read();p=read(); solve(); } return 0; }