题意:给定一个大小为(n)的点权集合,所有的点的点权必须是集合里的点权,问这样形成的二叉树的点权和为([1-m])的有多少个
题解:
首先我们可以设(G(i))表示(i)这个数是否在集合中,(f(i))表示点权和为(i)的方案数
那么我们可以考虑枚举一个点的点权和ta的两个子树的点权和
(f(0)=1\
f(k)=sum_{i=1}^{k}g(i)sum_{j=0}^{i+jle k}{f(j)f(k-i-j)})
所以可以发现(f=g*f^2)
直接解方程就可以得到(f=frac{1 pm sqrt{1-4g}}{2g})
这个(pm)该怎么处理?
先试试分母有理化
(f=frac{2}{1 pm sqrt{1-4g}})
可以发现一个问题
对于(f(0):g(0)=0 ; ; => ; ; 1-4g = 1)
要保证分母必须不为(0)
所以应该是(f=frac{2}{1 + sqrt{1-4g}})
多项式开根+求逆即可
代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
# define LL long long
using namespace std ;
const int M = 400050 ;
const int mod = 998244353 ;
const int G = 3 ;
const int Gi = mod / G + 1 ;
inline int read() {
char c = getchar() ; int x = 0 , w = 1 ;
while(c>'9'||c<'0') { if(c=='-') w = -1 ; c = getchar() ; }
while(c>='0'&&c<='9') { x = x*10+c-'0' ; c = getchar() ; }
return x * w ;
}
int n , m ;
int len , lim = 1 , r[M] ;
LL g[M] , a[M] , b[M] , c[M] , ans[M] , inv2 ;
inline LL Fpw(LL Base , LL k) {
LL temp = 1 ;
while(k) {
if(k & 1) temp = temp * Base % mod ;
Base = Base * Base % mod ; k >>= 1 ;
}
return temp ;
}
inline void NTT(LL *A , int unit) {
for(int i = 0 ; i < lim ; i ++) if(r[i] > i) swap(A[i] , A[r[i]]) ;
for(int mid = 1 ; mid < lim ; (mid <<= 1)) {
LL R = (mid << 1) , W = Fpw(unit > 0 ? G : Gi , (mod - 1) / R) ;
for(int j = 0 ; j < lim ; j += R) {
LL w = 1 ;
for(int k = 0 ; k < mid ; k ++ , w = (w * W) % mod) {
LL x = A[j + k] , y = w * A[j + k + mid] % mod ;
A[j + k] = (x + y) % mod ; A[j + k + mid] = (x - y) % mod ;
}
}
}
}
void Inv(int d , LL *A , LL *b) {
if(d == 1) { b[0] = Fpw(A[0] , mod - 2) ; return ; }
Inv((d + 1) >> 1 , A , b) ;
lim = 1 ; len = 0 ;
while(lim < (d << 1)) lim <<= 1 , ++ len ;
for(int i = 1 ; i < lim ; i ++) r[i] = ((r[i >> 1] >> 1) | ((i & 1) << (len - 1))) ;
for(int i = 0 ; i < d ; i ++) a[i] = A[i] ;
for(int i = d ; i < lim ; i ++) a[i] = 0 ;
NTT(b , 1) ; NTT(a , 1) ;
for(int i = 0 ; i < lim ; i ++) b[i] = ((2 - a[i] * b[i]) % mod * b[i] % mod + mod) % mod ;
NTT(b , -1) ; LL inv = Fpw(lim , mod - 2) ;
for(int i = 0 ; i < d ; i ++) b[i] = (b[i] * inv % mod + mod) % mod ;
for(int i = d ; i < lim ; i ++) b[i] = 0 ;
}
void Sqrt(int d , LL *A , LL *b) {
if(d == 1) { b[0] = 1 ; return ; }
Sqrt((d + 1) >> 1 , A , b) ;
for(int i = 0 ; i < (d << 1) ; i ++) c[i] = 0 ;
Inv(d , b , c) ;
lim = 1 ; len = 0 ;
while(lim < (d << 1)) lim <<= 1 , ++ len ;
for(int i = 1 ; i < lim ; i ++) r[i] = ((r[i >> 1] >> 1) | ((i & 1) << (len - 1))) ;
for(int i = 0 ; i < d ; i ++) a[i] = A[i] ;
for(int i = d ; i < lim ; i ++) a[i] = 0 ;
NTT(a , 1) ; NTT(c , 1) ;
for(int i = 0 ; i < lim ; i ++) a[i] = (a[i] * c[i] % mod * inv2 % mod + mod) % mod ;
NTT(a , -1) ; LL inv = Fpw(lim , mod - 2) ;
for(int i = 0 ; i < d ; i ++) a[i] = (a[i] * inv % mod + mod) % mod ;
for(int i = d ; i < lim ; i ++) a[i] = 0 ;
for(int i = 0 ; i < d ; i ++) b[i] = ((b[i] * inv2 % mod + a[i]) % mod + mod) % mod ;
for(int i = d ; i < lim ; i ++) b[i] = 0 ;
}
int main() {
n = read() ; m = read() ; inv2 = Fpw(2 , mod - 2) ;
for(int i = 1 , x ; i <= n ; i ++) {
x = read() ;
g[x] = (g[x] - 4 + mod) % mod ;
}
g[0] ++ ; Sqrt(m + 1 , g , b) ;
b[0] ++ ; Inv(m + 1 , b , ans) ;
for(int i = 1 ; i <= m ; i ++)
printf("%lld
",(2LL * ans[i] % mod + mod) % mod) ;
return 0 ;
}