【题解】CodeChef - TREDEG (prufer+生成函数+多项式exp)
好毒瘤的数据范围...
先转prufer,现在问题就变成了我要生成一个(n-2)长度的序列,每一种序列的权值定义为每种数的(prod)(每种数出现个数+1),可以直接使用指数型生成函数生成,具体的:
[(sum_{i=0} {(i+1)^kover i!}x^i)^n[x^{n-2}](n-2)!
]
这个就生成这个序列的答案了。用exp搞个快速幂就完事了。
最终答案的式子
[(sum_{i=0} {(i+1)^kover i!}x^i)^n[x^{n-2}](n-2)!over n^{n-2}
]
然后数据范围要我们单独做(k=1),那么把((sum_{i=0} {i+1over i!}x^i)^n[x^{n-2}])单独拿出来
[(sum_{i=0} {i+1over i!}x^i)^n[x^{n-2}]
]
用(e^x)代替
[(xe^x+e^x)^n[x^{n-2}]
]
二项式定理展开
[[x^{n-2}]sum {nchoose i} x^ie^{ix}e^{(n-i)x}
]
合并
[[x^{n-2}]sum {nchoose i} x^ie^{nx}
]
化简一下
[sum {nchoose i} e^{nx}[x^{n-2-i}]
]
泰勒展开一下
[sum {nchoose i} {n^{n-2-i}over (n-2-i)!}
]
就可以(O(n))算了。
代码:(很长)
//@winlere
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
//#define getchar() (__c==__ed?(__ed=__buf+fread(__c=__buf,1,1<<18,stdin),*__c++):*__c++)
using namespace std; typedef long long ll; char __buf[1<<18],*__c=__buf,*__ed=__buf;
inline int qr(){
int ret=0,f=0,c=getchar();
while(!isdigit(c)) f|=c==45,c=getchar();
while( isdigit(c)) ret=ret*10+c-48,c=getchar();
return f?-ret:ret;
}
const int maxn=1<<22;
const int mod=998244353;
const int g=3;
const int gi=(mod+1)/3;
inline int MOD(const int&x){return x>=mod?x-mod:x;}
inline int MOD(const int&x,const int&y){return 1ll*x*y%mod;}
int invs[maxn],jc[maxn],inv[maxn];
int ksm(const int&ba,const int&p){
int ret=1;
for(int t=p,b=ba;t;t>>=1,b=MOD(b,b))
if(t&1) ret=MOD(ret,b);
return ret;
}
void NTT(int*a,const int&len,const int&tag){
static int r[maxn];
for(int t=1;t<len;++t)
if((r[t]=r[t>>1]>>1|(t&1?len>>1:0))>t)
swap(a[t],a[r[t]]);
for(int t=1,s=tag==1?g:gi,wn;t<len;t<<=1){
wn=ksm(s,(mod-1)/(t<<1));
for(int i=0;i<len;i+=t<<1)
for(int j=0,w=1,p;j<t;++j,w=MOD(w,wn))
p=MOD(a[i+j+t],w),a[i+j+t]=MOD(a[i+j]-p+mod),a[i+j]=MOD(a[i+j]+p);
}
if(tag!=1)
for(int t=0,i=mod-(mod-1)/len;t<len;++t)
a[t]=MOD(a[t],i);
}
void Deri(int*a,const int&len){
for(int t=0;t<len-1;++t) a[t]=MOD(a[t+1],t+1);
a[len-1]=0;
}
void Inter(int*a,int*b,const int&len){
for(int t=len-1;t;--t) b[t]=MOD(a[t-1],invs[t]);
b[0]=0;
}
void INV(int*a,int*b,const int&len){
if(len==1) return b[0]=ksm(a[0],mod-2),void();
INV(a,b,len>>1);
static int A[maxn],B[maxn];
memset(A,0,len<<3); memset(B,0,len<<3);
memcpy(A,a,len<<2); memcpy(B,b,len<<2);
NTT(A,len<<1,1); NTT(B,len<<1,1);
for(int t=0;t<len<<1;++t) A[t]=MOD(B[t],MOD(A[t],B[t]));
NTT(A,len<<1,0);
for(int t=0;t<len;++t) b[t]=MOD(MOD(b[t]+b[t])-A[t]+mod);
}
void LN(int*a,int*b,const int&len){
static int A[maxn],B[maxn];
memset(A,0,len<<3); memset(B,0,len<<3);
memcpy(A,a,len<<2);
INV(A,B,len); Deri(A,len);
NTT(A,len<<1,1); NTT(B,len<<1,1);
for(int t=0;t<len<<1;++t) A[t]=MOD(A[t],B[t]);
NTT(A,len<<1,0);
Inter(A,b,len);
}
void EXP(int*a,int*b,const int&len){
if(len==1){b[0]=1;return;}
EXP(a,b,len>>1);
static int A[maxn],B[maxn];
memset(A,0,len<<3); memset(B,0,len<<3);
memcpy(A,b,len<<1); LN(b,B,len);
for(int t=0;t<len;++t) B[t]=MOD(a[t]-B[t]+mod);
++B[0];
NTT(A,len<<1,1); NTT(B,len<<1,1);
for(int t=0;t<len<<1;++t) A[t]=MOD(A[t],B[t]);
NTT(A,len<<1,0);
for(int t=0;t<len;++t) b[t]=A[t];
}
void POW(int*a,int*b,const int&len,const int&k){
static int A[maxn],B[maxn];
memset(A,0,len<<3); memcpy(A,a,len<<2);
memset(B,0,len<<3);
LN(A,A,len);
for(int t=0;t<len;++t) A[t]=MOD(A[t],k);
EXP(A,B,len);
memcpy(b,B,len<<2);
}
void pre(const int&n){
jc[0]=invs[1]=inv[0]=1;
for(int t=1;t<=n;++t) jc[t]=MOD(jc[t-1],t);
for(int t=2;t<=n;++t) invs[t]=MOD(mod-mod/t,invs[mod%t]);
for(int t=1;t<=n;++t) inv[t]=MOD(inv[t-1],invs[t]);
//for(int t=0;t<=n;++t) if(MOD(jc[t],inv[t])!=1) puts("wa"),cerr<<"t="<<t<<endl;
}
int c(const int&n,const int&m){
if(n<m) return 0;
return MOD(jc[n],MOD(inv[m],inv[n-m]));
}
int main(){
pre(maxn-1);
#ifdef debug
static int test[maxn],tesv[maxn];
for(int t=0;t<4;++t) test[t]=t+1;
NTT(test,8,1);
NTT(test,8,0);
for(int t=0;t<16;++t) fprintf(stderr,"%d%c",test[t],"
"[t==15]);
INV(test,tesv,4);
NTT(test,8,1); NTT(tesv,8,1);
for(int t=0;t<8;++t) test[t]=MOD(test[t],tesv[t]);
NTT(test,8,0);
for(int t=0;t<8;++t) fprintf(stderr,"%d%c",test[t],"
"[t==7]);
for(int t=0;t<8;++t) test[t]=0;
for(int t=0;t<4;++t) test[t]=t+1;
Deri(test,4);
for(int t=0;t<8;++t) fprintf(stderr,"%d%c",test[t],"
"[t==7]);
Inter(test,test,4);
Deri(test,4);
for(int t=0;t<8;++t) fprintf(stderr,"%d%c",test[t],"
"[t==7]);
for(int t=0;t<8;++t) test[t]=0;
for(int t=0;t<4;++t) test[t]=t+1;
POW(test,test,8,2);
for(int t=0;t<8;++t) fprintf(stderr,"%d%c",test[t],"
"[t==7]);
#endif
int T=qr();
while(T--){
int n=qr(),k=qr();
if(k!=1){
static int a[maxn];
int len=1;
while(len<=n) len<<=1;
memset(a,0,len<<2);
for(int t=0;t<=n;++t) a[t]=MOD(inv[t],ksm(t+1,k));
POW(a,a,len,n);
//cerr<<"a[n-2]="<<a[n-2]<<endl;
int ans=MOD(MOD(a[n-2],jc[n-2]),ksm(ksm(n,n-2),mod-2));
printf("%d
",ans);
}else{
int ans=0;
for(int t=0;t<=n-2;++t)
ans=MOD(ans+MOD(c(n,t),MOD(ksm(n,n-2-t),inv[n-2-t])));
ans=MOD(ans,MOD(jc[n-2],ksm(ksm(n,n-2),mod-2)));
printf("%d
",ans);
}
}
return 0;
}