原文链接https://www.cnblogs.com/zhouzhendong/p/UOJ269.html
题目传送门 - UOJ269
题意
有一个多项式函数 $f(x)$,最高次幂为 $x^m$,定义变换 $Q$:
$$Q(f,n,x)=sum_{k=0}^n f(k)inom nk x^k(1−x)^{n−k}$$
现在给定函数 $f$ 和 $n,x$,求 $Q(f,n,x)mod {
m 998244353}$。
$f(x)$ 由 $0$~$m$ 的点值给出。
$1leq nleq 10^9,1leq m leq 2 imes 10^4, 0leq a_i,x <998244353$
题解
cly_none 太强了。
考虑一个 $m$ 次多项式 $f(x)$ ,必然可以拆成一堆下降幂的和。(忽略系数)其中,最高次项是 $m$ 次项,所以转成下降幂之后,最高次项就是一个 $m$ 阶下降幂。
对于 $f(x)$ 的某一个下降幂表示,设为 $x^underline{k}$ ,那么,可以得到:
$$egin{aligned} & sum_{i=0}^n i^{underline k} {nchoose i} x^i (1-x)^{n-i} \ = & sum_{i=k}^n i^{underline k} frac {n^{underline k}}{i ^ {underline k}} {n - kchoose i - k } x^i (1-x)^{n-i} \ = & n^{underline k} sum_{i=k}^n {n - kchoose i - k } x^i (1-x)^{n-i} \ = & n^{underline k} x^k sum_{i=0}^{n-k} {n - kchoose i} x^i (1-x)^{n-k-i} \ = & n^{underline k} x^k end{aligned}$$
于是这样就证明了题目要求的式子是一个关于 $n$ 的 $m$ 次多项式。
于是只需要 FFT 一下,求出 $[0,m]$ 之间的整点的点值,然后插值来求答案。由于这些点值十分特殊,所以可以预处理阶乘来 $O(m)$ 求解。
总的时间复杂度为 $O(mlog m)$ 。
代码
#include <bits/stdc++.h> using namespace std; const int N=1<<16,mod=998244353; int read(){ int x=0; char ch=getchar(); while (!isdigit(ch)) ch=getchar(); while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); return x; } int Pow(int x,int y){ int ans=1; for (;y;y>>=1,x=1LL*x*x%mod) if (y&1) ans=1LL*ans*x%mod; return ans; } int n,m,x,a[N],A[N],B[N]; int Fac[N],Inv[N]; int w[N],R[N]; int C(int n,int m){ if (m>n||m<0) return 0; return 1LL*Fac[n]*Inv[m]%mod*Inv[n-m]%mod; } void FFT(int a[],int n){ for (int i=0;i<n;i++) if (R[i]<i) swap(a[R[i]],a[i]); for (int t=n>>1,d=1;d<n;d<<=1,t>>=1) for (int i=0;i<n;i+=(d<<1)) for (int j=0;j<d;j++){ int tmp=1LL*w[t*j]*a[i+j+d]%mod; a[i+j+d]=(a[i+j]+mod-tmp)%mod; a[i+j]=(a[i+j]+tmp)%mod; } } void Mul(int a[],int b[],int m){ int n,d; for (n=1,d=0;n<=m*2+2;n<<=1,d++); for (int i=0;i<n;i++) R[i]=(R[i>>1]>>1)|((i&1)<<(d-1)); w[0]=1,w[1]=Pow(3,(mod-1)/n); for (int i=2;i<n;i++) w[i]=1LL*w[i-1]*w[1]%mod; FFT(a,n); FFT(b,n); for (int i=0;i<n;i++) a[i]=1LL*a[i]*b[i]%mod; w[0]=1,w[1]=Pow(w[1],mod-2); for (int i=2;i<n;i++) w[i]=1LL*w[i-1]*w[1]%mod; FFT(a,n); int inv=Pow(n,mod-2); for (int i=0;i<n;i++) a[i]=1LL*a[i]*inv%mod; } int calc(int n){ int ans=0; for (int k=0;k<=n;k++) ans=(1LL*a[k]*C(n,k)%mod*Pow(x,k)%mod*Pow(mod+1-x,n-k)%mod+ans)%mod; return ans; } int main(){ n=read(),m=read(),x=read(); for (int i=0;i<=m;i++) a[i]=read(); for (int i=Fac[0]=Inv[0]=1;i<=m;i++){ Fac[i]=1LL*Fac[i-1]*i%mod; Inv[i]=1LL*Inv[i-1]*Pow(i,mod-2)%mod; } if (n<=m) return printf("%d ",calc(n)),0; for (int i=0;i<=m;i++){ A[i]=1LL*a[i]*Inv[i]%mod*Pow(x,i)%mod; B[i]=1LL*Inv[i]*Pow(mod+1-x,i)%mod; } Mul(A,B,m); for (int i=0;i<=m;i++) A[i]=1LL*A[i]*Fac[i]%mod; int ans=0; for (int i=0;i<=m;i++){ int t=1LL*A[i]*Inv[i]%mod*Inv[m-i]%mod; t=1LL*t*Pow(n+mod-i,mod-2)%mod; if ((m-i)&1) t=(mod-t)%mod; ans=(ans+t)%mod; } for (int i=n;i>=n-m;i--) ans=1LL*ans*i%mod; printf("%d",ans); return 0; }