题目:https://www.lydsy.com/JudgeOnline/problem.php?id=5093
不要见到组合数就拆!
枚举每个点的度数,则答案为 ( n*sumlimits_{i=0}^{n-1}C_{n-1}^{i}*2^{C_{n-1}^{2}}*i^{k} )
(又是那个公式:( x^{n}=sumlimits_{k=0}^{n}C_{x}^{k}*(k!)*S(n,k) ))
( = n*2^{C_{n-1}^{2}}sumlimits_{i=0}^{n-1}C_{n-1}^{i}sumlimits_{j=0}^{k}C_{i}^{j}*(j!)*S(k,j) )
这里发现组合数的角标有一样的,不要把那两个组合数拆了以消掉阶乘,而可以通过组合意义把它们合起来!
( = n*2^{C_{n-1}^{2}}sumlimits_{j=0}^{k}(j!)*S(k,j)sumlimits_{i=0}^{n-1}C_{n-1}^{i}*C_{i}^{j} )
从 n-1 个数里选 i 个数,再从 i 个数里选 j 个数,而且 i 从 0 枚举到 n-1 ,就可以看作从 n-1 个数里选了 j 个数,剩下 n-1-j 个数可选可不选。
(比如一个点 1 想连到另一个点 2 , 1 先在 n-1 个点里选 i 个点连上,再从这 i 个点里选 j 个点连到点 2 ; 也即点 1 在 n-1 个点里选了 j 个点连向点 2 ,其余的点可能和点 1 相连)
所以 ( = n*2^{C_{n-1}^{2}}sumlimits_{j=0}^{k}(j!)*S(k,j)*C_{n-1}^{j}*2^{n-1-j} )
用 NTT 预处理斯特林数就行了。别把现在的组合数拆掉,因为是一维特别大,一维特别小,所以分子和分母消一下。需要预处理阶乘和下降幂,才能做到 O(1) 算组合数。
注意指数上模 mod-1 。
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; const int N=2e5+5,M=(1<<19)+5,mod=998244353; int n,m,s[N],a[M],b[M],jcn[N],ljc[N],len,r[M]; void upd(int &x){x>=mod?x-=mod:0;} int pw(int x,int k) {int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;} void ntt(int *a,bool fx) { for(int i=0;i<len;i++) if(i<r[i])swap(a[i],a[r[i]]); for(int R=2;R<=len;R<<=1) { int wn=pw( 3,fx?(mod-1)-(mod-1)/R:(mod-1)/R ); for(int i=0,m=R>>1;i<len;i+=R) for(int j=0,w=1;j<m;j++,w=(ll)w*wn%mod) { int x=a[i+j], y=(ll)w*a[i+m+j]%mod; a[i+j]=x+y; upd(a[i+j]); a[i+m+j]=x+mod-y; upd(a[i+m+j]); } } if(!fx)return ; int inv=pw(len,mod-2); for(int i=0;i<len;i++)a[i]=(ll)a[i]*inv%mod; } void init() { jcn[0]=1;for(int i=1;i<=m;i++)jcn[i]=(ll)jcn[i-1]*i%mod; jcn[m]=pw(jcn[m],mod-2);for(int i=m-1;i>=0;i--)jcn[i]=(ll)jcn[i+1]*(i+1)%mod; for(int i=0,j=1;i<=m;i++,j=-j) a[i]=j*jcn[i]+mod,upd(a[i]); for(int i=0;i<=m;i++) b[i]=(ll)pw(i,m)*jcn[i]%mod; for(len=1;len<=m<<1;len<<=1); for(int i=0;i<len;i++)r[i]=(r[i>>1]>>1)+((i&1)?len>>1:0); ntt(a,0); ntt(b,0); for(int i=0;i<len;i++)a[i]=(ll)a[i]*b[i]%mod; ntt(a,1); for(int i=0;i<=m;i++)s[i]=a[i]; ljc[0]=1;for(int i=n-1,j=1;j<=m;j++,i--)ljc[j]=(ll)ljc[j-1]*i%mod; } int C(int m) { return (ll)ljc[m]*jcn[m]%mod; } int main() { scanf("%d%d",&n,&m); init(); int ans=0; for(int i=0,jc=1;i<=m;i++,jc=(ll)jc*i%mod) { if(i>n-1)break;//or pw() ans=(ans+(ll)jc*C(i)%mod*pw(2,n-1-i)%mod*s[i])%mod; } ans=(ll)ans*n%mod*pw(2,(ll)(n-1)*(n-2)/2%(mod-1))%mod;//mod-1 printf("%d ",ans); return 0; }