CF1119H Triple [* hard]
给定 (n,k),以及 (n) 个三元组 ((a_i,b_i,c_i)),常数 (x,y,z) 表示每个三元组均有 (x) 个 (a_i),(y) 个 (b_i),(z) 个 (c_i),求从每个数组中选出一个数使得异或和为 (tin [0,2^k)) 的方案数。
保证 (a_i,b_i,c_i<2^k,nle 10^5,kle 17)
( m Sol:)
首先可以 (mathcal O(n2^kk)) 的做 FWT
然后众所周知 (c(i,j)=(-1)^{i& j}),每个元素都形如 (f_1 imes x^{a_i}+f_2 imes x^{b_i}+f_3 imes x^{c_i}) 这样的一个集合幂级数,可能的 FWT 值仅有 (8) 种,如果能对每个点值 (x) 算出取值数就可以知道答案了。
然后考虑一个极其巧妙的转换,不妨给所有三元组均异或上 (a_i) 然后视为 ((0,a_ioplus b_i,a_ioplus c_i)) 这样的三元组,这样考虑 FWT 点值必然形如 (f_1+f_2 imes (-1)^{i&(...)}+f_3 imes (-1)^{i&(...)})
于是 FWT 数组有且仅有 (4) 种点值 ((f_1+f_2+f_3,f_1-f_2-f_3,f_1+f_2-f_3,f_1-f_2+f_3)),分别设为 ((a,b,c,d))。
于是只需要考虑列出 (4) 个方程来求解这 (4) 个点值的数量(黎明前的巧克力)
- (c_1+c_2+c_3+c_4=n)
事实上,我们进行分类讨论:
- (c_1) 即 (b) 与 (c) 取值均为 (1) 的数量。
- (c_2) 即 (b) 与 (c) 取值均为 (-1) 的数量。
- (c_3) 即 (b) 取值为 (1),(c) 取值为 (-1) 的数量。
- (c_4) 即 (b) 取值为 (-1),(c) 取值为 (1) 的数量。
于是将上值加起来做 FWT 得到的结果应该是 (c_1+c_3-c_2-c_4)
将 (c) 处取值设为 (1),那么 FWT 得到的结果应该是 (c_1+c_4-c_2-c_3) 的值。
然后最后最为巧妙的是,我们令 (boplus c) 为 (1)
这样得到的结果显然就是 (c_1+c_2-c_3-c_4)
于是设我们得到的值分别为 (A,B,C,D),那么就有:
解得:
然后我们进行 IFWT 即可得到答案。
最后这一步考虑正负来计数实在是 tql !
综上,我们得到了一个 (mathcal O(k2^k+n)) 的优秀做法辣!
(Code:)
#include<bits/stdc++.h>
using namespace std ;
#define Next( i, x ) for( register int i = head[x]; i; i = e[i].next )
#define rep( i, s, t ) for( register int i = (s); i <= (t); ++ i )
#define drep( i, s, t ) for( register int i = (t); i >= (s); -- i )
#define re register
#define int long long
int gi() {
char cc = getchar() ; int cn = 0, flus = 1 ;
while( cc < '0' || cc > '9' ) { if( cc == '-' ) flus = - flus ; cc = getchar() ; }
while( cc >= '0' && cc <= '9' ) cn = cn * 10 + cc - '0', cc = getchar() ;
return cn * flus ;
}
const int P = 998244353 ;
const int IP = 499122177 ;
const int N = 1e5 + 5 ;
const int M = 3e5 + 5 ;
int n, f1, f2, f3, f4, limit, fA, f[M] ;
struct node { int x, y ; } c[N] ;
struct fucti { int a1, a2, a3, a4 ; } g[M] ;
int fpow( int x, int k ) {
int ans = 1, base = x ;
while(k) {
if(k & 1) ans = ans * base % P ;
base = base * base % P, k >>= 1 ;
} return ans ;
}
void FWT( int *a, int type ) {
for( re int k = 1; k < limit; k <<= 1 )
for( re int i = 0; i < limit; i += ( k << 1 ) )
for( re int j = i; j < i + k; ++ j ) {
int nx = a[j], ny = a[j + k] ;
a[j] = (nx + ny) % P, a[j + k] = (nx - ny + P) % P ;
if( !type ) a[j] = a[j] * IP % P, a[j + k] = a[j + k] * IP % P ;
}
}
void init() {
memset( f, 0, sizeof(f) ) ;
}
void Mod( int &x ) {
x += P, x += P, x %= P ;
}
signed main()
{
n = gi(), limit = gi(), limit = 1 << limit ; int x, y, z ;
x = gi(), y = gi(), z = gi() ;
f1 = x + y + z, f2 = x - y - z, f3 = x + y - z, f4 = x - y + z ;
Mod(f1), Mod(f2), Mod(f3), Mod(f4) ;
rep( i, 1, n ) {
x = gi(), y = gi() ^ x, z = gi() ^ x, fA ^= x ;
c[i].x = y, c[i].y = z ;
}
rep( i, 1, n ) ++ f[c[i].x] ; FWT( f, 1 ) ;
for( re int i = 0; i < limit; ++ i ) g[i].a1 = n, g[i].a2 = f[i] ;
init() ; rep( i, 1, n ) ++ f[c[i].y] ; FWT( f, 1 ) ;
for( re int i = 0; i < limit; ++ i ) g[i].a3 = f[i] ;
init() ; rep( i, 1, n ) ++ f[c[i].x ^ c[i].y] ; FWT( f, 1 ) ;
for( re int i = 0; i < limit; ++ i ) g[i].a4 = f[i] ;
int ip = IP * IP % P ;
for( re int i = 0; i < limit; ++ i ) {
int A = g[i].a1, B = g[i].a2, C = g[i].a3, D = g[i].a4 ;
int c1 = A + B + C + D, c2 = A + D - B - C,
c3 = A + B - C - D, c4 = A + C - B - D ;
Mod(c1), Mod(c2), Mod(c3), Mod(c4),
c1 = c1 * ip % P, c2 = c2 * ip % P, c3 = c3 * ip % P, c4 = c4 * ip % P,
f[i] = fpow( f1, c1 ) * fpow( f2, c2 ) % P * fpow( f3, c3 ) % P * fpow( f4, c4 ) % P ;
}
FWT( f, 0 ) ;
for( re int i = 0; i < limit; ++ i ) printf("%lld ", f[fA ^ i] ) ;
return 0 ;
}