分治FFT
考虑(F(i)),显然因为(F)是个卷积的形式(虽然我们不知道其中的某一部分),因此有:
[F(x)=sum_{i+j=x} F(i)G(j)
]
因此考虑我们计算出了前边一段的(F)值,可以通过乘上(G)中的一部分让这个(F)整体右移,如(F(1)-F(3))卷上(G(1)-G(3))就成为了(F(4)-F(6))中得的一部分。
因此考虑分治。
考虑我们已经计算出了一段([l,mid])中的真实(F)值,我们给右边的部分加上这些的贡献。
那么很不显然就是(F[l,mid])卷上一个(G[0,r-l])就得到了(F[mid,r])的一部分。
那么我们每次分治计算左区间后,(NTT)计算出左边对右边的贡献,然后累加上去即可。
对比代码理解更好哦QAQ。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#define N 500005
#define pb push_back
#define g 3
#define gi 332748118
#define mod 998244353
#define int long long
using namespace std;
int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
return x*f;
}
int n,rev[N];
vector<int>F,G,S,T;
int ksm(int a,int b)
{
int res=1;
while(b)
{
if(b&1)res*=a,res%=mod;
a*=a;a%=mod;b>>=1;
}
return res%mod;
}
void NTT(vector<int>&a,int limit,int type)
{
for(int i=0;i<limit;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
for(int mid=1;mid<limit;mid<<=1)
{
int Wn=ksm(type==1?g:gi,(mod-1)/(mid<<1));
for(int j=0;j<limit;j+=(mid<<1))
{
int w=1;
for(int k=0;k<mid;k++,w=(w*Wn%mod)%mod)
{
int x=a[j+k]%mod,y=w*a[j+k+mid]%mod;
a[j+k]=(x+y)%mod;
a[j+k+mid]=(x-y+mod)%mod;
}
}
}
if(type==-1)
{
int INV=ksm(limit,mod-2);
for(int i=0;i<limit;i++)a[i]=a[i]*INV%mod;
}
}
int get_limit(int x)
{
int limit=1;while(limit<=x)limit<<=1;
for(int i=0;i<limit;i++)rev[i]=((rev[i>>1]>>1)|((i&1)?limit>>1:0));
return limit;
}
vector<int> operator*(vector<int>&a,vector<int>&b)
{
int len=a.size()+b.size()-1;
int limit=get_limit(len);
a.resize(limit);b.resize(limit);
NTT(a,limit,1);NTT(b,limit,1);
for(int i=0;i<limit;i++)a[i]=a[i]*b[i]%mod;
NTT(a,limit,-1);a.resize(len);
return a;
}
void solve(int l,int r)
{
if(l==r)return;
int mid=(l+r)>>1;
solve(l,mid);
S.clear();T.clear();
for(int i=l;i<=mid;i++)S.pb(F[i]),T.pb(G[i-l]);
for(int i=mid+1;i<=r;i++)S.pb(0),T.pb(G[i-l]);
S=S*T;
for(int i=mid+1;i<=r;i++)F[i]=(F[i]+S[i-l])%mod;
solve(mid+1,r);
}
signed main()
{
n=read();G.pb(0);F.pb(1);
for(int i=1;i<n;i++)G.pb(read());
for(int i=1;i<n;i++)F.pb(0);
solve(0,n-1);
for(int i=0;i<n;i++)printf("%d ",F[i]);
return 0;
}