有个多项式(f(x)=sum_{i=0}^na_ix^i)。规定(i=0..n,f(i)in[1,c_i])。问有多少组合法的(a_i)((a_i)都是整数)
(nle 6)
为了方便把范围定成([0,c_i-1])。
首先写成下降幂多项式,好处是:(f(i))的表达式可以由(a_{0..i})确定。设(f(x)=sum_{i=0}^n q_i x^{underline i})。
假如决定了(q_{0..i-1}),对于(f(i))的表达式,要求:(0le q_ii!+sum_{j=0}^{i-1}q_ji^{underline j}<c_i)。设(C=sum_{j=0}^{i-1}q_ji^{underline j}),则解的个数为(lfloorfrac{c_i}{i!} floor+[c_imod i!>cmod i!])。
注意到(q_ji^{underline j}=(q_jmod (i-j)!)i^{underline j} pmod {i!})。这样我们就将需要知道的(q_j)的取值限定了。
设(r_i=q_imod (n-i)!)。枚举(r_i)。
于是对于(f(i)),令(q_i=(n-i)!t+r_i),则要求:(0le (n-i)!i!t+r_ii!+sum_{j=0}^{i-1}q_ji^{underline j}<c_i)。
类似地,需要求出((r_ii!+sum_{j=0}^{i-1}q_ji^{underline j})mod (n-i)!i!),它等于(sum_{j=0}^i r_ii^{underline j}mod (n-i)!i!)。(因为((n-i)!i!|(n-j)!i^{underline j}),相除就是个组合数)
于是就得到了:(O(nprod_{i=0}^{n} i!))的时间。
然后可以最后枚举(r_0)。可以发现根据每个位置的贡献不同将(r_0)分段,位置(i)会分出(O(frac{n}{(n-i)!i!}))段,加起来刚好就是(O(2^n))。搞一搞可以将时间中的一个(n!)换成(2^nlg 2^n)。题解说这个(lg)可以去掉说是用计数排序(好像要离线一下排序量比较大的时候才做)。
然而感觉常数才是最大的问题,卡不过去,TLE90爬
using namespace std;
#include <bits/stdc++.h>
#define ll long long
#define fi first
#define se second
#define mp make_pair
const int N=6;
const int mo=998244353;
const double _mo=1.0/mo;
ll qpow(ll x,ll y=mo-2){
ll r=1;
for (;y;y>>=1,x=x*x%mo)
if (y&1)
r=r*x%mo;
return r;
}
int n;
ll fac[N+1];
ll c[N+1];
ll r[N+1],C[N+1],f[N+1],tran[(N+1)*2],g[N+1],b[N+1];
ll ans;
pair<int,int> q[1<<N+2];
void dfs(int i){
if (i>n){
int k=0;
q[k++]=mp(0,0);
q[k++]=mp(g[0],0);
for (int j=1;j<=n;++j){
ll m=fac[j]*fac[n-j];
for (int t=0;t*m-C[j]<fac[n];++t){
q[k++]=mp(t*m-C[j],j);
q[k++]=mp(t*m+g[j]-C[j],j);
}
}
sort(q,q+k);
//return;
int cnt0=0;
ll pro=1;
for (int j=0;j<=n;++j){
b[j]=0;
if (f[j]==0) cnt0++; else pro=pro*f[j]%mo;
}
ll lst=0;
for (int t=0;t<k && q[t].fi<fac[n];++t){
if (lst<q[t].fi){
if (!cnt0)
ans+=pro*(q[t].fi-lst);
lst=q[t].fi;
}
int w=q[t].se;
if (f[w]){
pro=pro*tran[w<<1|b[w]]%mo;
//pro*=tran[w<<1|b[w]];
//pro=pro-(ll)(pro*_mo)*mo;
}
else
cnt0+=(b[w]?1:-1);
b[w]^=1;
}
if (!cnt0)
ans+=pro*(fac[n]-lst);
ans%=mo;
return;
}
C[i]=0;
ll m=fac[i]*fac[n-i];
for (int j=0;j<i;++j)
C[i]=(C[i]+r[j]*(fac[i]/fac[i-j]))%m;
for (int v=0;v<fac[n-i];++v){
r[i]=v;
dfs(i+1);
C[i]=(C[i]+fac[i])%m;
}
}
int main(){
fac[0]=1;
for (int i=1;i<=N;++i)
fac[i]=fac[i-1]*i;
int T;
scanf("%d",&T);
while (T--){
scanf("%d",&n);
for (int i=0;i<=n;++i)
scanf("%lld",&c[i]);
ans=0;
for (int i=0;i<=n;++i){
f[i]=c[i]/(fac[i]*fac[n-i]);
tran[i<<1]=qpow(f[i])*(f[i]+1)%mo;
tran[i<<1|1]=qpow(f[i]+1)*(f[i])%mo;
g[i]=c[i]%(fac[i]*fac[n-i]);
}
dfs(1);
ans%=mo;
printf("%lld
",ans);
}
return 0;
}