题目大意:有 $n$ 个互不相同的正整数 $c_i$。问对于每一个 $1le ile m$,有多少个不同形态(考虑结构和点权)的二叉树满足每个点权都在 $c$ 中出现过,且点权和为 $i$。答案对 $998244353$ 取模。
$1le n,mle 10^5$。
首先考虑DP,$f_i$ 表示点权和为 $i$ 的树数。
那么枚举根节点的点权和两棵子树的点权和 $f_k=sumlimits^n_{i=1}c_isumlimits^{k-c_i}_{j=0}f_jf_{k-c_i-j}$。
初始状态 $f_0=1$。因为空树也能作为子树。
这样的复杂度是 $O(nm^2)$,不能过。
考虑 $c$ 的生成函数 $C(x)=sum x^{c_i}$ 和 $f$ 的生成函数 $F(x)=sum f_ix^i$。(你问我怎么想到的?我也不知道啊)
那么容易发现原来的式子就是几个函数的卷积。
$F=C imes F imes F+1$(注意 $f_0=1$)
$C imes F^2-F+1=0$
$F=dfrac{1pmsqrt{1-4C}}{2C}$
接下来看看上面该取正还是负。
取正时 $limlimits_{x ightarrow 0}F(x)=+infty$,不收敛,舍去。
取负时 $limlimits_{x ightarrow 0}F(x)=1$,符合题意。
那么 $F=dfrac{1-sqrt{1-4C}}{2C}=dfrac{2}{1+sqrt{1-4C}}$。
直接套模板即可。时间复杂度 $O(mlog m)$。
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int maxn=333333,mod=998244353; #define FOR(i,a,b) for(int i=(a);i<=(b);i++) #define ROF(i,a,b) for(int i=(a);i>=(b);i--) #define MEM(x,v) memset(x,v,sizeof(x)) inline int read(){ char ch=getchar();int x=0,f=0; while(ch<'0' || ch>'9') f|=ch=='-',ch=getchar(); while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar(); return f?-x:x; } int n,m,c[maxn],lim,l,rev[maxn],invtmp[maxn],Binv[maxn],sqrtmp[maxn],Csqrt[maxn],Cinv[maxn]; inline void init(int upr){ for(lim=1,l=0;lim<upr;lim<<=1,l++); FOR(i,0,lim-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1)); } inline int add(int a,int b){return a+b<mod?a+b:a+b-mod;} inline int sub(int a,int b){return a<b?a-b+mod:a-b;} inline int qpow(int a,int b){ int ans=1; for(;b;b>>=1,a=1ll*a*a%mod) if(b&1) ans=1ll*ans*a%mod; return ans; } void NTT(int *A,int tp){ FOR(i,0,lim-1) if(i<rev[i]) swap(A[i],A[rev[i]]); for(int i=1;i<lim;i<<=1) for(int j=0,r=i<<1,Wn=qpow(3,mod-1+tp*(mod-1)/r);j<lim;j+=r) for(int k=0,w=1;k<i;k++,w=1ll*w*Wn%mod){ int x=A[j+k],y=1ll*A[i+j+k]*w%mod; A[j+k]=add(x,y);A[i+j+k]=sub(x,y); } if(tp==-1) for(int i=0,linv=qpow(lim,mod-2);i<lim;i++) A[i]=1ll*A[i]*linv%mod; } void poly_inv(int *A,int *B,int deg){ if(deg==1) return void(B[0]=qpow(A[0],mod-2)); poly_inv(A,B,(deg+1)>>1); init(deg<<1); FOR(i,0,deg-1) invtmp[i]=A[i]; FOR(i,deg,lim-1) invtmp[i]=0; NTT(invtmp,1);NTT(B,1); FOR(i,0,lim-1) B[i]=1ll*sub(2,1ll*invtmp[i]*B[i]%mod)*B[i]%mod; NTT(B,-1); FOR(i,deg,lim-1) B[i]=0; } void poly_sqrt(int *A,int *B,int deg){ if(deg==1) return void(B[0]=1); poly_sqrt(A,B,(deg+1)>>1); init(deg<<1); FOR(i,0,lim-1) Binv[i]=0; poly_inv(B,Binv,deg); init(deg<<1); FOR(i,0,deg-1) sqrtmp[i]=A[i]; FOR(i,deg,lim-1) Binv[i]=sqrtmp[i]=0; NTT(sqrtmp,1);NTT(Binv,1); FOR(i,0,lim-1) sqrtmp[i]=1ll*sqrtmp[i]*Binv[i]%mod; NTT(sqrtmp,-1); FOR(i,0,deg-1) B[i]=499122177ll*add(B[i],sqrtmp[i])%mod; FOR(i,deg,lim-1) B[i]=0; } int main(){ n=read();m=read(); FOR(i,1,n){ int x=read(); if(x<=m) c[x]=1; } FOR(i,1,m) c[i]=(mod-4ll*c[i]%mod)%mod; c[0]=1; poly_sqrt(c,Csqrt,m+1); Csqrt[0]=add(Csqrt[0],1); poly_inv(Csqrt,Cinv,m+1); FOR(i,1,m) printf("%d ",add(Cinv[i],Cinv[i])); }