归纳每一次操作后必然是两个颜色相同的连续段(即ww...bb...或bb...ww...),对操作的位置分类讨论不难证明正确性
当$c_{1}=c_{n}$,由于端点颜色不会修改,再根据该结论,可以得到$f(s,c_{i})=c_{1}cdot n$(w为0,b为$n$)
当$c_{1} e c_{n}$(以下假设$c_{1}=b$且$c_{n}=w$),令$x=min_{c_{i}=w}i$且$y=max_{c_{i}=b}i$,考虑答案的上下限,最坏情况下即为$[1,x)$,最好情况下为$[1,y]$
对$x-1$和$y+1$哪个先取分类讨论:
1.若$x-1$先取,则当选择$y$后(也有可能是先选$y$再选$x-1$,但同理),$[1,y]$必然都为黑色(且不会再被翻转),即达到上限,因此如果$x-1$到$s$的距离小于等于$s$到$y+1$的距离,则答案为$[1,y]$
2.若$y+1$先取,类似的可以得到$[x,n]$都为白色,即答案取到下限$[1,x)$,因此如果$x-1$到$s$的距离大于$s$到$y+1$的距离,则答案为$[1,x)$
(另外对于$c_{1}=w$且$c_{n}=b$,不是两倍而是$s'=n-s+1$时的答案)
这显然包含了所有情况,即答案仅与$x$、$y$和$s$有关,得到了一个暴力$o(n^{3})$的做法,代码如下
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N 200005 4 #define mod 998244353 5 int n,mi[N],ans[N]; 6 int main(){ 7 scanf("%d",&n); 8 if (n==1){ 9 printf("%d",(mod+1)/2); 10 return 0; 11 } 12 mi[0]=1; 13 for(int i=1;i<=n;i++)mi[i]=mi[i-1]*2%mod; 14 for(int i=2;i<=n;i++){ 15 for(int j=1;j<=n;j++)ans[j]=(ans[j]+i-1)%mod; 16 for(int j=i+1;j<n;j++) 17 for(int k=1;k<=n;k++) 18 if (abs(i-1-k)<=abs(j+1-k))ans[k]=(ans[k]+1LL*j*mi[j-i-1])%mod; 19 else ans[k]=(ans[k]+1LL*(i-1)*mi[j-i-1])%mod; 20 } 21 int inv=1; 22 for(int i=1;i<=n;i++)inv=1LL*(mod+1)/2*inv%mod; 23 for(int i=1;i<=n;i++)printf("%lld ",(ans[i]+ans[n-i+1]+1LL*mi[n-2]*n)%mod*inv%mod); 24 }
将$y=x-1$的特殊情况累加后即为$frac{n(n-1)}{2}$,然后对于$k$的枚举改为差分,时间复杂度降为$o(n^{2})$,代码如下
1 for(int i=2;i<=n;i++) 2 for(int j=i+1;j<n;j++){ 3 int k=(j+i)/2; 4 ans[1]=(ans[1]+1LL*j*mi[j-i-1])%mod; 5 ans[k+1]=(ans[k+1]-1LL*(j-i+1)*mi[j-i-1]%mod+mod)%mod; 6 } 7 for(int i=2;i<=n;i++)ans[i]=(ans[i]+ans[i-1])%mod; 8 int inv=1,s=1LL*n*(n-1)/2%mod; 9 for(int i=1;i<=n;i++){ 10 inv=1LL*(mod+1)/2*inv%mod; 11 ans[i]=(ans[i]+s)%mod; 12 }
进一步的,对于$ans[1]$的修改比较好处理,对于$ans[k+1]$的修改可以枚举$j-i$,那么$k=frac{j-i}{2}+i$,再进行一次差分即可,时间复杂度即降为$o(n)$
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N 200005 4 #define mod 998244353 5 int n,mi[N],ans[N]; 6 int main(){ 7 scanf("%d",&n); 8 if (n==1){ 9 printf("%d",(mod+1)/2); 10 return 0; 11 } 12 mi[0]=1; 13 for(int i=1;i<=n;i++)mi[i]=mi[i-1]*2%mod; 14 for(int i=1;i<=n-3;i++){ 15 ans[i/2+3]=(ans[i/2+3]-1LL*(i+1)*mi[i-1]%mod+mod)%mod; 16 ans[i/2+n-i+1]=(ans[i/2+n-i+1]+1LL*(i+1)*mi[i-1]%mod)%mod; 17 } 18 for(int i=2;i<=n;i++)ans[i]=(ans[i]+ans[i-1])%mod; 19 for(int i=3;i<n;i++)ans[1]=(ans[1]+1LL*i*(mi[i-2]-1))%mod; 20 for(int i=2;i<=n;i++)ans[i]=(ans[i]+ans[i-1])%mod; 21 int inv=1,s=1LL*n*(n-1)/2%mod; 22 for(int i=1;i<=n;i++){ 23 inv=1LL*(mod+1)/2*inv%mod; 24 ans[i]=(ans[i]+s)%mod; 25 } 26 for(int i=1;i<=n;i++)printf("%lld ",(ans[i]+ans[n-i+1]+1LL*mi[n-2]*n)%mod*inv%mod); 27 }