题目大意
给出(n,t,x,y,z),值域(le 2^t),给出(n)个三元组((a_i,b_i,c_i)),表示有(x)个(a_i),(y)个(b_i),(z)个(c_i)。对于每个(kin [0,2^t-1]),求出从每组选出一个数的异或值为(k)的方案数。
思路
先定义([delta ^n]F(delta))表示多项式(F)的第(n)位系数。
不得不说,这道题真的很妙。而且也加深了我对FWT的理解。
如果我们对每个三元组创建生成函数为:
那么对于(k)的答案就是:
这里的乘法定义为异或卷积,就是([delta^n]F(delta) G(delta)=sum_{iotimes j=n} [delta^i] F(delta) [delta^j] G(delta))。
可以看出来,上面这个式子可以使用( ext {FWT})进行优化,但是(Theta(nk2^k))的时间复杂度显然通过不了这道题。
但是这道题十分特殊,因为每个多项式只有(3)个点有值。于是,我们可以考虑在( ext {FWT})的过程中进行优化。
我们可以考虑到,对于([delta^i]FWT(F_n))有
其中,(cnt(i))表示(i)在二进制下(1)的个数。
如果我们按照上面的构造的话,时间复杂度虽然降为(Theta((n+k)2^k)),但是仍然无法通过。
但是我们发现对于([delta^i]FWT(F_n)),它的值只可能有(8)种,形式为(x+y+z,x+y-z,...)。于是,我们可以自然地想到一种方法,即计算对于([delta ^i]prod FWT(F_n))每一种出现了多少次,然后直接快速幂即可。
但是(8)种显然太多了,我们可以通过把一个三元组((a_i,b_i,c_i))变为((0,b_iotimes a_i,c_iotimes a_i)),最后答案再异或上(otimes_{i=1}^{n} a_i)。这样我们就只有(4)种情况了。
为了方便,我们设(S =prod_{i=1}^{n} FWT(F_i))。这里的乘法就是([delta^i]F(delta) G(delta) =[delta^i]F(delta) [delta^i]G(delta))。我们对于([delta^i]S)可以设(c1,c2,c3,c4)表示:
于是,我们的目标就是通过设立(4)个本质不同的一次方程解出(c1,c2,c3,c4)。
首先,非常显然:
然后呢?我们发现其实(c1,c2,c3,c4)跟(x,y,z)其实半毛钱都没有,于是我们可以对(x,y,z)代入特定值解出(c1,c2,c3,c4)。
- (x=0,y=1,z=0)
这样,我们就可以得到:
非常显然,你把(x=0,y=1,z=0)带进去就好了。
- (x=0,y=0,z=1)
与上面一样,我们可以得到:
但是我们现在似乎只有(3)个方程。。。蛤?你说(x=1,y=0,z=0),显然它与((1))式本质相同。
于是,我们现在需要一点大大的脑洞。
- ([delta^{b_kotimes c_k}]F(delta)=1)
我们发现这种情况其实就是上面两种情况的卷积,就可以得到:
于是,我们可以得到:
于是,我们就可以得到:
于是,我们就可以在(Theta(n+k2^k))时间内解决。
( ext {Code})
#include <bits/stdc++.h>
using namespace std;
#define Int register int
#define inv2 499122177
#define mod 998244353
#define int long long
template <typename T> inline void read (T &t){t = 0;char c = getchar();int f = 1;while (c < '0' || c > '9'){if (c == '-') f = -f;c = getchar();}while (c >= '0' && c <= '9'){t = (t << 3) + (t << 1) + c - '0';c = getchar();} t *= f;}
template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);}
template <typename T> inline void write (T x){if (x < 0){x = -x;putchar ('-');}if (x > 9) write (x / 10);putchar (x % 10 + '0');}
int quick_pow (int a,int b){
int res = 1;
while (b){
if (b & 1) res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
int inv (int x){return quick_pow (x,mod - 2);}
int n,t,lim,f[1 << 17],f1[1 << 17],f2[1 << 17],f3[1 << 17];
void DWT (int *a){
for (Int i = 1;i < lim;i <<= 1)
for (Int j = 0;j < lim;j += i << 1)
for (Int k = 0;k < i;++ k){
int x = a[j + k],y = a[j + k + i];
a[j + k] = x + y,a[i + j + k] = x - y;
}
}
void IDWT (int *a){
for (Int i = 1;i < lim;i <<= 1)
for (Int j = 0;j < lim;j += i << 1)
for (Int k = 0;k < i;++ k){
int x = a[j + k],y = a[j + k + i];
a[j + k] = (x + y) * inv2 % mod,a[i + j + k] = (x + mod - y) * inv2 % mod;
}
}
signed main(){
read (n,t),lim = 1 << t;int x,y,z,sta = 0;read (x,y,z);
for (Int i = 1,a,b,c;i <= n;++ i) read (a,b,c),sta ^= a,b ^= a,c ^= a,f1[b] ++,f2[c] ++,f3[b ^ c] ++;
DWT (f1),DWT (f2),DWT (f3);
int s1 = (x + y + z) % mod,s2 = (x + y - z) % mod,s3 = (x - y + z) % mod,s4 = (x - y - z) % mod;
for (Int i = 0;i < lim;++ i){
int c1 = (n + f1[i] + f2[i] + f3[i]) / 4;
f[i] = quick_pow (s1,c1) *
quick_pow (s2,(n + f1[i] - c1 * 2) / 2) % mod *
quick_pow (s3,(n + f2[i] - c1 * 2) / 2) % mod *
quick_pow (s4,(n + f3[i] - c1 * 2) / 2) % mod;
}
IDWT (f);
for (Int i = 0;i < lim;++ i) write ((f[i ^ sta] + mod) % mod),putchar (' ');
putchar ('
');
return 0;
}
参考博客
https://www.luogu.com.cn/blog/command-block/solution-cf1119h