对每个 \(a_i\),建出一个多项式 \(F(x) = \sum\limits_{j=1}^{a_i} \binom{a_i-1}{j-1} x^j\)。其中 \(j\) 次项系数表示这些卡牌被分成 \(j\) 段的方案数,也即它们有 \(a_i-j\) 处被强制为魔术对的方案数。
将这些多项式视为 EGF(即把 \(j\) 次项系数再除以 \(j!\))做卷积,得到 \(G(x)\)。则 \((n-i)!\,[x^{n-i}]G(x)\) 即为强制有 \(i\) 处为魔术对的方案数:此时共 \(n-i\) 段,乘以 \((n-i)!\) 对应把这些段排成一列。
最后用二项式反演即可得到答案:\(ans = \sum\limits_{i=k}^{n} (-1)^{i-k} \binom{i}{k} (n-i)!\,[x^{n-i}]G(x)\)。
如何求出 \(m\) 个多项式(长度之和为 \(n\))的乘积呢?
用一个堆维护当前所有多项式,每次取出长度最短的两个相乘后再放回堆中。可以证明总复杂度不超过 \(O(n\log^2 n)\)。
Code:
#include <bits/stdc++.h>
#define LL long long
using namespace std;
// Reads the next non-negative decimal integer from stdin into x.
// Every character before the first digit (including '-') is skipped, so the
// sign is not honoured. If end of input is reached before any digit, x is 0.
template <typename T> void read(T &x){
    // Keep getchar()'s result as int: it must be able to hold EOF, and
    // isdigit() has undefined behaviour for negative values other than EOF
    // (which a signed char holding a byte >= 0x80 would produce).
    int ch = std::getchar();
    x = 0;
    while (ch != EOF && !std::isdigit(ch)) ch = std::getchar();
    while (ch != EOF && std::isdigit(ch)){
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
}
// Writes the decimal representation of x to stdout, with no separator.
// Negative values get a leading '-'. The magnitude is computed in unsigned
// arithmetic so that INT_MIN is printed correctly as well.
inline void write(int x){
    unsigned int mag = static_cast<unsigned int>(x);
    if (x < 0){
        std::putchar('-');
        mag = 0u - mag;
    }
    char digits[16];
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + mag % 10u);
        mag /= 10u;
    } while (mag != 0u);
    while (len > 0) std::putchar(digits[--len]);
}
// NTT-friendly prime modulus: P - 1 = 119 * 2^23.
const int P = 998244353;
// Primitive root modulo P.
const int g = 3;
// Maximum transform length / size of the global scratch buffers (2^17).
const int L = 131072;
// Capacity bound for the per-colour input arrays (a[M], pool A[M<<1]).
const int M = 20050;
// Not referenced by the visible code.
const int N = 100050;
inline int power(int x,int y){
static int r; r = 1; while (y){ if (y&1) r = (LL)r * x % P; x = (LL)x * x % P,y >>= 1; }
return r;
}
// Root-of-unity tables filled in main: rt[l] = g^((P-1)/2^l) is used as the
// twiddle base for butterfly level l, irt[l] is its modular inverse.
int rt[30];
int irt[30];
// Bit-reversal permutation for the current transform length (set by getR).
int R[L];
// Modular inverses, factorials and inverse factorials for 0..L (set in main).
int inv[L+5];
int fac[L+5];
int nfac[L+5];
inline int C(int n,int m){
return (n<0||m<0||n<m) ? 0 : ((LL)fac[n] * nfac[m] % P * nfac[n-m]) % P;
}
// Fills R with the bit-reversal permutation for the smallest power of two
// strictly greater than n, and returns that length.
inline int getR(int n){
    int bits = 0;
    int len = 1;
    while (len <= n){
        len <<= 1;
        ++bits;
    }
    for (int i = 1; i < len; ++i)
        R[i] = (R[i >> 1] >> 1) | ((i & 1) << (bits - 1));
    return len;
}
// In-place forward number-theoretic transform of A[0..n-1] modulo P.
// n must be a power of two and R must hold the bit-reversal permutation for
// length n (as prepared by getR). Butterflies at level l (span 2^l) use rt[l].
// (The former `register` specifiers were removed: they are ill-formed in C++17.)
inline void NTT(int *A,int n){
    for (int i = 1; i < n; ++i)
        if (i < R[i]) std::swap(A[i],A[R[i]]);
    for (int half = 1, lvl = 1; half < n; half <<= 1, ++lvl){
        const int w0 = rt[lvl];
        for (int j = 0; j < n; j += half << 1){
            int w = 1;
            for (int k = j; k < j + half; ++k, w = (LL)w * w0 % P){
                const int x = (LL)w * A[k+half] % P;
                // Both operands are in [0, P), so one conditional add/subtract
                // keeps the results in range without a modulo.
                A[k+half] = (A[k] < x) ? (A[k] + P - x) : (A[k] - x);
                A[k] = (A[k] + x >= P) ? (A[k] + x - P) : (A[k] + x);
            }
        }
    }
}
inline void iNTT(int *A,int n){
register int i,j,k,l,w,w0,x;
for (i = 1; i < n; ++i) if (i < R[i]) swap(A[i],A[R[i]]);
for (i = l = 1; i < n; i <<= 1,++l)
for (w0 = irt[l],j = 0; j < n; j += i<<1)
for (w = 1,k = j; k < i+j; ++k,w = (LL)w * w0 % P)
x = (LL)w * A[k+i] % P,A[k+i] = (A[k]<x)?(A[k]+P-x):(A[k]-x),
A[k] = (A[k]+x>=P)?(A[k]+x-P):(A[k]+x);
for (i = 0,w = inv[n]; i < n; ++i) A[i] = (LL)A[i] * w % P;
}
// Polynomial as a coefficient vector (index = exponent).
typedef std::vector<int> arr;
// Scratch buffers for Mul's transforms; work() also writes its result into F.
int F[L];
int G[L];
inline void Mul(arr &A,arr &B,arr &C){
int n = A.size()-1,m = B.size()-1,Li = getR(n+m); register int i;
for (memset(F,0,Li<<2),i = 0; i <= n; ++i) F[i] = A[i];
for (memset(G,0,Li<<2),i = 0; i <= m; ++i) G[i] = B[i];
NTT(F,Li); NTT(G,Li); for (i = 0; i < Li; ++i) F[i] = (LL)F[i] * G[i] % P; iNTT(F,Li);
C.resize(n+m+1); for (i = 0; i <= n+m; ++i) C[i] = F[i];
}
// Polynomial pool: build() fills A[1..m] with the per-colour inputs and work()
// appends intermediate products after them. cnto is the last slot used.
arr A[M<<1];
int cnto;
// Heap entry: index into the polynomial pool plus its coefficient count.
// The comparison is inverted so that a std::priority_queue (a max-heap)
// pops the shortest polynomial first.
struct Node{
    int id;
    int len;
    bool operator < (const Node &other) const{ return other.len < len; }
}tmp;
priority_queue<Node>H;
int m,n,k,a[M];
inline void build(int n){
A[++cnto].resize(n+1);
for (int i = 0; i <= n; ++i) A[cnto][i] = (LL)nfac[i] * C(n-1,i-1) % P;
}
inline void work(){
int i,id1,id2;
for (i = 1; i <= cnto; ++i) tmp.id = i,tmp.len = A[i].size(),H.push(tmp);
while (H.size() > 1){
id1 = H.top().id,H.pop(),id2 = H.top().id,H.pop();
Mul(A[id1],A[id2],A[++cnto]);
tmp.id = cnto,tmp.len = A[cnto].size(),H.push(tmp);
}
for (i = 0; i <= n; ++i) F[n-i] = (LL)fac[i] * A[cnto][i] % P;
}
int main(){
int i,j;
for (i = 1,j = 2; i <= 25; ++i,j <<= 1) rt[i] = power(3,(P-1)/j),irt[i] = power(rt[i],P-2);
inv[0] = inv[1] = nfac[0] = fac[0] = nfac[1] = fac[1] = 1;
for (i = 2; i <= L; ++i){
fac[i] = (LL)fac[i-1] * i % P;
inv[i] = (LL)(P-P/i) * inv[P%i] % P;
nfac[i] = (LL)nfac[i-1] * inv[i] % P;
}
read(m),read(n),read(k);
for (i = 1; i <= m; ++i) read(a[i]),build(a[i]);
work();
int ans = 0;
for (i = k; i <= n; ++i){
if ((k-i) & 1) ans = (ans + P - (LL)C(i,k) * F[i] % P) % P;
else ans = (ans + (LL)C(i,k) * F[i] % P) % P;
}
cout << ans << '
';
return 0;
}