定义集合$S$合法当且仅当:$Ssubseteq [1,n]$,$|S|=k$且$forall iin [d,n],|Scap(i-d,i]|le 1$
问题即求$sum_{S合法}sum_{xin S}a_{x}$
记$F(n,k)=sum_{S合法}1$和$G(n,k,i)=sum_{iin S且S合法}1$(合法指对此时的$n$和$k$,$d$是常数),交换枚举顺序,答案即$sum_{i=1}^{n}sum_{iin S且S合法}1=sum_{i=1}^{n}a_{i}G(n,k,i)$
(特别的,当$i otin [1,n]$时,定义$G(n,k,i)=0$)
关于$G(n,k,i)$,由于$iin S$,因此在$(i-d,i)$和$(i,i+d)$中不能有其他元素,那么不妨将$[i,i+d)$这一段全部删去(不能把$(i-d,i+d)$都删去,因为$i-d$和$i+d$是可以同时选的)
此时,方案数可以容斥计算,即先不考虑$(i-d,i)$中不能有元素的限制,方案数即$F(n-d,k-1)$,再枚举其中的元素,显然至多仅有1个,即$sum_{j=i-d+1}^{i-1}G(n-d,k-1,j)$
综上,得到递推式——
$$
G(n,k,i)=F(n-d,k-1)-sum_{j=i-d+1}^{i-1}G(n-d,k-1,j)
$$
(特别的,当$i
otin [d,n+1-d]$时,可以类似地分析,最终的式子是相同的)
不难发现,前两维一定是形如$(n-td,k-t)$的形式,记$V_{t}(x)=sum_{i=1}^{n-td}G(n-td,k-t,i)x^{i}$,根据上面的递推式有
$$
V_{t}(x)=F(n-td-d,k-t-1)sum_{i=1}^{n-td}x^{i}-sum_{i=1}^{d-1}x^{i}V_{t+1}(x)
$$
为了方便,以下记$C=F(n-td-d,k-t-1)$
对于前者,维护懒标记$tag_{t}$,并令$V_{t}(x)=V'_{t}(x)+tagsum_{i=1}^{n-td}x^{i}$,我们记录$(V'_{t}(x),tag_{t})$,考虑维护这个二元组的转移,即
$$
V'_{t}(x)+tag_{t}sum_{i=1}^{n-td}x^{i}=Csum_{i=1}^{n-td}x^{i}-sum_{i=1}^{d-1}x^{i}(V'_{t+1}(x)+tag_{t+1}sum_{j=1}^{n-td-d}x^{j})
$$
(不难发现每一个$tag_{t}$都可以对应一个$V'_{t}(x)$,我们只需要维护其中一种即可)
对其展开并调整使两边形式类似,即
$$
V'_{t}(x)+tag_{t}sum_{i=1}^{n-td}x^{i}=-sum_{i=1}^{d-1}x^{i}V'_{t+1}(x)+Csum_{i=1}^{n-td}x^{i}-tag_{t+1}sum_{i=1}^{d-1}x^{i}sum_{j=1}^{n-td-d}x^{j}
$$
对于最后一项,其$i$次项系数即$min(i,n-td-i+1,d)-1$(其中$iin [1,n-td]$),将系数补成$d-1$,再将抵消的项放到$V'_{t}(x)$中,并将两项对应提出,即
$$
egin{cases}V'_{t}(x)=tag_{t+1}sum_{i=1}^{d-1}(d-i)(x^{i}+x^{n-td-i+1})-sum_{i=1}^{d-1}x^{i}V'_{t+1}(x)\tag_{t}=C-(d-1)tag_{t+1}end{cases}
$$
由于$kle lceilfrac{n}{d}
ceil$,即$(k-1)dle n$,那么初始即$(V'_{k-1}(x),tag_{k-1})=(0,1)$,最终即求$(V'_{0}(x),tag_{0})$
由于$tag_{t}$与$V'_{t}$无关,先来计算$tag_{t}$,即需要求$F(n,k)$
关于$F(n,k)$,即$sum_{i=0}^{k}x_{i}=n-k$的方案数,其中$x_{i}in N$且$forall 1le i<k,x_{i}ge d-1$
将中间这$k-1$项每一项都减去$d-1$,再用插板法可得$F(n,k)={n-(k-1)(d-1)choose k}$,$o(1)$计算即可
由此,即可在$o(n)$的时间内求得$tag_{t}$
令$h(x)=-sum_{i=1}^{d-1}x^{i}$和$h_{t}(x)=tag_{t+1}sum_{i=1}^{d-1}(d-i)(x^{i}+x^{n-td-i+1})$,考虑直接求$V'_{0}(x)$的通项公式,通过计算每一次$h_{t}(x)$的影响,即$V'_{0}(x)=sum_{i=0}^{k-2}h_{i}(x)h^{i}(x)$
将$h_{t}(x)$拆为$A_{t}(x)=tag_{t+1}sum_{i=1}^{d-1}(d-i)x^{i}$和$B_{t}(x)=tag_{t+1}sum_{i=1}^{d-1}(d-i)x^{n-td-i+1}$,根据分配律,将两者分别计算,即$V'_{0}(x)=sum_{i=0}^{k-2}A_{i}(x)h^{i}(x)+sum_{i=0}^{k-2}B_{i}(x)h^{i}(x)$
分治+FFT计算即可,其中对于$B_{i}(x)$提取一个$x^{n-td-d+1}$这个公因式就可以降低次数
最终时间复杂度为$o(nlog nlog k)$,被卡常了
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N 1000005 4 #define mod 998244353 5 struct Poly{ 6 int m; 7 vector<int>a; 8 Poly(){ 9 m=0; 10 a.clear(); 11 } 12 }ans,h[25],A[N<<1]; 13 int n,k,d,x,sum,fac[N],inv[N],tag[N<<1]; 14 int C(int n,int m){ 15 if (n<m)return 0; 16 return 1LL*fac[n]*inv[m]%mod*inv[n-m]%mod; 17 } 18 int F(int n,int k){ 19 return C(n-(k-1)*(d-1),k); 20 } 21 int Log2(int n){ 22 int m=0; 23 while ((1<<m)<n)m++; 24 return m; 25 } 26 int qpow(int n,int m){ 27 int s=n,ans=1; 28 while (m){ 29 if (m&1)ans=1LL*ans*s%mod; 30 s=1LL*s*s%mod; 31 m>>=1; 32 } 33 return ans; 34 } 35 void Add(Poly &a,int n){ 36 while (a.a.size()<n)a.a.push_back(0); 37 } 38 void Dec(Poly &a){ 39 while ((a.a.size())&&(!a.a.back()))a.a.pop_back(); 40 } 41 void ntt(Poly &a,int n,int p){ 42 for(int i=0;i<(1<<n);i++){ 43 int s=0; 44 for(int j=0;j<n;j++) 45 if (i&(1<<j))s+=(1<<n-j-1); 46 if (i<s)swap(a.a[i],a.a[s]); 47 } 48 for(int i=2;i<=(1<<n);i<<=1){ 49 int s=qpow(3,(mod-1)/i); 50 if (p)s=qpow(s,mod-2); 51 for(int j=0;j<(1<<n);j+=i) 52 for(int k=0,ss=1;k<(i>>1);k++,ss=1LL*ss*s%mod){ 53 int x=a.a[j+k],y=1LL*a.a[j+k+(i>>1)]*ss%mod; 54 a.a[j+k]=(x+y)%mod; 55 a.a[j+k+(i>>1)]=(x+mod-y)%mod; 56 } 57 } 58 if (p){ 59 int s=qpow((1<<n),mod-2); 60 for(int i=0;i<a.a.size();i++)a.a[i]=1LL*a.a[i]*s%mod; 61 } 62 } 63 Poly add(Poly x,Poly y){ 64 Poly ans; 65 if (x.m>y.m)swap(x,y); 66 ans=x; 67 y.m-=x.m; 68 Add(ans,y.m+y.a.size()); 69 for(int i=0;i<y.a.size();i++)ans.a[i+y.m]=(ans.a[i+y.m]+y.a[i])%mod; 70 return ans; 71 } 72 Poly mul(Poly x,Poly y){ 73 Poly ans; 74 ans.m=x.m+y.m; 75 int n=Log2(x.a.size()+y.a.size()); 76 Add(x,(1<<n)),Add(y,(1<<n)); 77 ntt(x,n,0); 78 ntt(y,n,0); 79 for(int i=0;i<(1<<n);i++)ans.a.push_back(1LL*x.a[i]*y.a[i]%mod); 80 ntt(ans,n,1); 81 Dec(ans); 82 return ans; 83 } 84 Poly calc(int k,int l,int r){ 85 if (!k)return A[l]; 86 int mid=(l+r>>1); 87 return add(calc(k-1,l,mid),mul(calc(k-1,mid+1,r),h[k-1])); 88 } 89 int main(){ 90 fac[0]=inv[0]=inv[1]=1; 91 for(int i=1;i<N;i++)fac[i]=1LL*fac[i-1]*i%mod; 92 for(int i=2;i<N;i++)inv[i]=1LL*(mod-mod/i)*inv[mod%i]%mod; 93 for(int i=1;i<N;i++)inv[i]=1LL*inv[i-1]*inv[i]%mod; 94 scanf("%d%d%d",&n,&k,&d); 95 tag[k-1]=1; 96 for(int i=k-2;i>=0;i--)tag[i]=(F(n-i*d-d,k-i-1)-1LL*(d-1)*tag[i+1]%mod+mod)%mod; 97 h[0].a.push_back(0); 98 for(int i=1;i<d;i++)h[0].a.push_back(mod-1); 99 int kk=Log2(k); 100 for(int i=1;i<kk;i++)h[i]=mul(h[i-1],h[i-1]); 101 for(int i=0;i<(1<<kk);i++) 102 if (tag[i+1]){ 103 A[i].a.push_back(0); 104 for(int j=1;j<d;j++)A[i].a.push_back(1LL*tag[i+1]*(d-j)%mod); 105 } 106 ans=calc(kk,0,(1<<kk)-1); 107 for(int i=0;i<(1<<kk);i++){ 108 A[i].m=max(n-i*d-d+1,0); 109 A[i].a.clear(); 110 if (tag[i+1]){ 111 A[i].a.push_back(0); 112 for(int j=1;j<d;j++)A[i].a.push_back(1LL*tag[i+1]*j%mod); 113 } 114 } 115 ans=add(ans,calc(kk,0,(1<<kk)-1)); 116 Add(ans,n+1); 117 for(int i=1;i<=n;i++)ans.a[i]=(ans.a[i]+tag[0])%mod; 118 for(int i=1;i<=n;i++){ 119 scanf("%d",&x); 120 sum=(sum+1LL*x*ans.a[i])%mod; 121 } 122 printf("%d",sum); 123 }