题目描述:
算法标签:斯特林树,分治ntt
思路:
对于固定一个点,讨论他有几个叶子的情况,观察规律发现他的系数斯特林数。于是我们可以求出对于每一个节点建成多少个联通快的方案树,再把每一个节点的方案书卷积起来,用分治ntt维护。
对于求斯特林数,可以先求出小的一部分,对于大的一部分,再用斯特林数的通式求法。
以下代码:
#include<bits/stdc++.h> #define il inline #define LL long long #define vet vector<int> #define _(d) while(d(isdigit(ch=getchar()))) using namespace std; const int N=3e5+5,p=998244353; int G[2],jc[N],ny[N],in[N]; int n,c,k,sz,s[318][318],a[N],b[N],num,sum[N]; il int read(){ int x,f=1;char ch; _(!)ch=='-'?f=-1:f;x=ch^48; _()x=(x<<1)+(x<<3)+(ch^48); return f*x; } il int ksm(LL a,int y){ LL b=1; while(y){ if(y&1)b=b*a%p; a=a*a%p;y>>=1; } return b; } il int mu(int x,int y){ if(x+y>=p)return x+y-p; return x+y; } class ntt{ int v[N],t,l; il void init(int x){ t=1;l=0; while(t<=x)t<<=1,l++; for(int i=0;i<t;i++)v[i]=(v[i>>1]>>1)|((i&1)<<l-1); } il void dft(int *x,int op){ for(int i=0;i<t;i++)if(i<v[i])swap(x[i],x[v[i]]); for(int i=1;i<t;i<<=1){ int wn=ksm(G[op],(p-1)/(i<<1)); for(int j=0;j<t;j+=i<<1){ for(int k=0,w=1;k<i;k++,w=1ll*w*wn%p){ int A=x[j+k],B=1ll*x[i+j+k]*w%p; x[j+k]=mu(A,B);x[i+j+k]=mu(A,p-B); } } } if(op){ int kk=ksm(t,p-2); for(int i=0;i<t;i++)x[i]=1ll*x[i]*kk%p; } } public: il void mult(int x){ init(x); dft(a,0);dft(b,0); for(int i=0;i<t;i++)a[i]=1ll*a[i]*b[i]%p; dft(a,1); } il void clear(){ for(int i=0;i<t;i++)a[i]=b[i]=0; } }T; il vet work(int x){ vet res;res.resize(x+1); if(x<=sz){ for(int i=1;i<=x;i++)res[i]=s[x][i]; } else{ for(int i=0;i<=x;i++){ if(i&1)a[i]=p-ny[i];else a[i]=ny[i]; b[i]=1ll*ksm(i,x)*ny[i]%p; } T.mult(x<<1); for(int i=1;i<=x;i++)res[i]=a[i]; T.clear(); } if(num^x){ for(int i=0;i<x;i++)res[i]=res[i+1]; int tmp=1;res.resize(x); for(int i=1;i<x;i++)tmp=1ll*(c-i)*tmp%p,res[i]=1ll*res[i]*tmp%p; } else{ num=0; int tmp=1; for(int i=1;i<=x;i++)tmp=1ll*(c-i+1)*tmp%p,res[i]=1ll*res[i]*tmp%p; } return res; } il vet Solve(int l,int r){ if(l==r)return work(in[l]); int mid=upper_bound(sum+l,sum+r+1,sum[l-1]+((sum[r]-sum[l-1])>>1))-sum-1; vet res1=Solve(l,mid),res2=Solve(mid+1,r),res; int s1=res1.size(),s2=res2.size(),s=s1+s2-1; res.resize(s); for(int i=0;i<s1;i++)a[i]=res1[i]; for(int i=0;i<s2;i++)b[i]=res2[i]; T.mult(s); for(int i=0;i<s;i++)res[i]=a[i]; T.clear(); return res; } int main() { n=read();c=read();k=read();sz=(int)sqrt(n); jc[0]=1;for(int i=1;i<=n;i++)jc[i]=1ll*i*jc[i-1]%p; ny[n]=ksm(jc[n],p-2);for(int i=n;i;i--)ny[i-1]=1ll*i*ny[i]%p; s[0][0]=1;G[0]=3;G[1]=ksm(3,p-2); for(int i=1;i<=sz;i++){ for(int j=1;j<=i;j++) s[i][j]=mu(s[i-1][j-1],1ll*j*s[i-1][j]%p); } for(int i=1;i<n;i++)in[read()]++,in[read()]++; num=in[1];sort(in+1,in+1+n); for(int i=1;i<=n;i++)sum[i]=sum[i-1]+in[i]; vet res=Solve(1,n); int ans=0; for(int i=1;i<n;i++)ans=mu(ans,1ll*ksm(i,k)*res[i]%p); printf("%d ",ans); return 0; }