[bzoj3625][Codeforces Round #250]小朋友和二叉树
标签: 多项式开方 多项式求逆
Description
一棵树的所有点的点权都是给定的集合C中的一个数。
让你求出1到m中所有权值为i的树的个数。
两棵树不同当且仅当树的形态不一样或者是树的某个点的点权不一样
对(998244353)取模
(n,m<=10^5)
(jiade)Solution
递推式应该挺好写的。
设(c(i))表示i是否在集合C中。
[f(n)=sum_{i=1}^n c(i)sum _{j=0}^{n-i}f(j)f(n-i-j)
]
特别的,(f(0)=1)
然后把(f(i))提出来。
[f(n)=sum_{i=0}^{n-1}f(i)sum _{j=0}^{n-i}f(j)c(n-i-j)
]
可以注意到这是一个卷积套卷积的形式。
设(g(n)=sum_{j=0}^nf(i)c(n-i))
那么$$f(n)=sum_{j=0}^{n-1}f(i)×g(n-i)$$
cdq+NTT即可。
##Real Solution 应该没有人会信我的鬼话吧 ~~其实我一开始是这么写的~~ 在cdq分治的时候,首先会求出[l,mid]内的$f(i)$ 然后再与[1,r-l+1]中的$g(i)$做卷积。 可是当l=1时,有些$g(i)$是还没有求出来的。 所以卷积个鬼啊。
我们设生成函数(F(i)=sum_{i=0}^n f(i)x^i),(C(i)=sum _{i=1}c(i)x^i)
那么有(F(x)=F(x)^2C(x)+1)
解得$$F(x)={1+(-) {sqrt {1-4C(x)}} over 2C(x)}$$
即$$F(x)={2 over {1-(+)sqrt {1-4C(x)}}}$$
显然符号是取(+)号的。
因为当(x=0)如果取(-)号那么分母为0没有意义。
所以$$F(x)={2 over {1+sqrt {1-4C(x)}}}$$
多项式开方和多项式求逆即可。
Code
给一份假的代码。
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<queue>
#include<stack>
#include<set>
#include<map>
using namespace std;
#define ll long long
#define REP(i,a,b) for(int i=(a),_end_=(b);i<=_end_;i++)
#define DREP(i,a,b) for(int i=(a),_end_=(b);i>=_end_;i--)
#define EREP(i,a) for(int i=start[(a)];i;i=e[i].next)
inline int read()
{
int sum=0,p=1;char ch=getchar();
while(!(('0'<=ch && ch<='9') || ch=='-'))ch=getchar();
if(ch=='-')p=-1,ch=getchar();
while('0'<=ch && ch<='9')sum=sum*10+ch-48,ch=getchar();
return sum*p;
}
const int mod=998244353;
const int maxn=4e5+20;
int n,m,f[maxn],g[maxn],ok[maxn],rev[maxn],A[maxn],B[maxn];
inline int power(int a,int b)
{
int ans=1;
while(b)
{
if(b & 1)ans=(ll)ans*a%mod;
b>>=1;
a=(ll)a*a%mod;
}
return ans;
}
inline void init()
{
m=read();n=read();
REP(i,1,m)ok[read()]=1;
ok[0]=1;
}
inline void NTT(int *p,int n,int op)
{
REP(i,0,n-1)if(i<rev[i])swap(p[i],p[rev[i]]);
for(int i=1;i<n;i<<=1)
{
int W=power(3,(mod-1)/(i<<1));
for(int j=0;j<n;j+=i<<1)
{
int w=1;
for(int k=j;k<i+j;k++,w=(ll)W*w%mod)
{
int x=p[k],y=(ll)p[k+i]*w%mod;
p[k]=x+y;p[k+i]=x-y;
if(p[k]>=mod)p[k]-=mod;
if(p[k+i]<0)p[k+i]+=mod;
}
}
}
if(op==-1)
{
int inv=power(n,mod-2);
REP(i,0,n-1)p[i]=(ll)p[i]*inv%mod;
reverse(p+1,p+n);
}
}
void solve(int l,int r)
{
if(l==r){g[l]=(g[l]+ok[l])%mod;f[l]=(f[l]+g[l])%mod;return;}
int mid=(l+r)>>1;
solve(l,mid);
int N=1,L=0;
while(N<=2*(mid-l+1))N<<=1,L++;
REP(i,1,N-1)rev[i]=(rev[i>>1]>>1)|(1<<(L-1));
/*--- get f(i) ---*/
REP(i,0,N-1)A[i]=B[i]=0;
REP(i,0,mid-l)A[i]=f[i+l];REP(i,1,mid-l+1)B[i]=g[i];
NTT(A,N,1);NTT(B,N,1);
REP(i,0,N-1)A[i]=(ll)A[i]*B[i]%mod;
NTT(A,N,-1);
REP(i,mid-l+1,r-l)f[i+l]=(f[i+l]+A[i])%mod;
/*--- get g(i) ---*/
REP(i,0,N-1)A[i]=B[i]=0;
REP(i,0,mid-l)A[i]=f[i+l];REP(i,1,mid-l+1)B[i]=ok[i];
NTT(A,N,1);NTT(B,N,1);
REP(i,0,N-1)A[i]=(ll)A[i]*B[i]%mod;
NTT(A,N,-1);
REP(i,mid-l+1,r-l)g[i+l]=(g[i+l]+A[i])%mod;
solve(mid+1,r);
}
inline void doing()
{
f[0]=g[0]=1;
solve(1,n);
REP(i,1,n)printf("%d
",f[i]);
}
int main()
{
init();
doing();
return 0;
}
正确代码
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<queue>
#include<stack>
#include<set>
#include<map>
using namespace std;
#define ll long long
#define REP(i,a,b) for(int i=(a),_end_=(b);i<=_end_;i++)
#define DREP(i,a,b) for(int i=(a),_end_=(b);i>=_end_;i--)
#define EREP(i,a) for(int i=start[(a)];i;i=e[i].next)
inline int read()
{
int sum=0,p=1;char ch=getchar();
while(!(('0'<=ch && ch<='9') || ch=='-'))ch=getchar();
if(ch=='-')p=-1,ch=getchar();
while('0'<=ch && ch<='9')sum=sum*10+ch-48,ch=getchar();
return sum*p;
}
const int maxn=4e5+20;
const int mod=998244353;
const int inv2=(mod+1)/2;
int n,c[maxn],m,rev[maxn];
inline int power(int a,int b)
{
int ans=1;
while(b)
{
if(b & 1)ans=(ll)ans*a%mod;
b>>=1;
a=(ll)a*a%mod;
}
return ans;
}
inline void NTT(int *p,int N,int op)
{
int n=1,l=0;while(n<N)n<<=1,l++;
REP(i,1,n-1)rev[i]=(rev[i>>1]>>1)|((i & 1)<<(l-1));
REP(i,0,n-1)if(i<rev[i])swap(p[i],p[rev[i]]);
for(int i=1;i<n;i<<=1)
{
int W=power(3,(mod-1)/(i<<1));
for(int j=0;j<n;j+=i<<1)
{
int w=1;
for(int k=j;k<i+j;k++,w=(ll)w*W%mod)
{
int x=p[k],y=(ll)p[k+i]*w%mod;
p[k]=x+y;p[k+i]=x-y;
if(p[k]>mod)p[k]-=mod;
if(p[k+i]<0)p[k+i]+=mod;
}
}
}
if(op==-1)
{
int inv=power(n,mod-2);
REP(i,0,n-1)p[i]=(ll)p[i]*inv%mod;
reverse(p+1,p+n);
}
}
int A[maxn],B[maxn],C[maxn];
int tmp[maxn];
void Inv(int *p,int *q,int len)
{
if(len==1)
{
q[0]=power(p[0],mod-2);
return;
}
Inv(p,q,len>>1);
REP(i,0,len-1)A[i]=p[i],B[i]=q[i];
NTT(A,len<<1,1);NTT(B,len<<1,1);
REP(i,0,(len<<1)-1)A[i]=(ll)B[i]*B[i]%mod*A[i]%mod;
NTT(A,len<<1,-1);
REP(i,0,len-1)q[i]=((-A[i]+2*q[i])%mod+mod)%mod;
REP(i,0,len<<1)A[i]=B[i]=0;
}
void Sqrt(int *p,int *q,int len)
{
if(len==1)
{
q[0]=p[0];
return;
}
Sqrt(p,q,len>>1);
REP(i,0,len)C[i]=p[i];
Inv(q,tmp,len);
NTT(tmp,len<<1,1);NTT(C,len<<1,1);
REP(i,0,(len<<1)-1)tmp[i]=(ll)tmp[i]*C[i]%mod;
NTT(tmp,len<<1,-1);
REP(i,0,len-1)q[i]=(ll)(tmp[i]+q[i])*inv2%mod;
REP(i,0,len<<1)C[i]=tmp[i]=0;
}
void init()
{
m=read();n=read();
REP(i,1,m)c[read()]=1;
}
int d[maxn];
void doing()
{
int N=1,l=0;while(N<=n)N<<=1,l++;
REP(i,0,N-1)c[i]=(-4*c[i]+mod)%mod;
c[0]++;
Sqrt(c,d,N);
REP(i,0,N-1)c[i]=0;
d[0]=(d[0]+1)%mod;
Inv(d,c,N);
REP(i,0,N-1)c[i]=c[i]*2%mod;
REP(i,1,n)printf("%d
",c[i]);
}
int main()
{
init();
doing();
return 0;
}