发现自己一年之前非常垃圾
题目大意是给你一个(n)个点的环,给每个点一个([1,a_i])的取值,并且满足环上任意相连两点权值不能相等,求方案数
考虑断环为链,发现不大会
不妨考虑所有(a_i)均相等的情况,设(m=a_i)
对于第一个点,有(m)种选择,其后每一个点的取值都不能和上一个相等,即(m-1)种选择,于是整个环就是(m(m-1)^{n-1})
吗?
显然不是,这样我们不能保证(1)号点和(n)号点的取值不相等。设(f_i)表示(1)号点恰好和(n-i+1)到(n)号点取值相等的情况,我们算的(m(m-1)^{n-1})其实等于(f_0+f_1)
考虑如何消掉(f_1),我们可以强行将(n)和(1)取值相同,其余点还是不能和前一个点取值相等,方案数是(m(m-1)^{n-2}=f_1+f_2);更一般的(m(m-1)^{n-i}=f_{i-1}+f_i),但是有一个特殊情况,即(m(m-1)=f_{n-2}),即让(3)号点到(n)号点都和(1)号点取值相等,这样(2)号点和(1)不相同自然就不会和后面的点相同。
我们要求的是(f_0),我们发现(f_0+f_1-(f_1+f_2)+f_2+f_3-....-f_{n-2}+f_{n-2}=f_0),即我们配一个(-1)的容斥系数即能求出(f_0)。
于是我们利用这个容斥就能断环为链,所以我们来考虑更一般的链上问题,即(a_i)不同的情况。
有一个显然的暴力(dp),设(dp_{i,j})表示第(i)个点取值为(j)的方案数,(s_i=sum_{j=1}^{a_i}dp_{i,j}),转移显然有(dp_{i,j}=s_{i-1}-dp_{i-1,j})
由于我们的容斥本质上是使得最后连续的一段和(1)取值相等,这一段连续的取值受限于这一段中(a_i)的最小值,于是我们不妨选一个(a_i)最小的点作为一号点,这样每次(dp)的初值就不会改变,只需做一次(dp)即可。
考虑优化这个(dp)
当(a_i>a_{i-1})的时候,对于(forall jin [1,a_{i-1}])有(dp_{i,j}=s_{i-1}-dp_{i-1,j});当(j>a_{i-1})的时候,由于(dp_{i-1,j}=0),所以对于(jin(a_{i-1},a_i])有(dp_{i,j}=s_{i-1})
当(a_{i}<a_{i-1})的时候,对于(forall jin [1,a_{i}])有(dp_{i,j}=s_{i-1}-dp_{i-1,j});当(j>a_i)的时候,则有(dp_{i,j}=0)
不难发现这几个转移我们只需要一个能够支持区间取反、区间加以及区间覆盖的数据结构就能维护,于是直接使用线段树来整体dp即可。
由于(a_i)比较大,所以得动态开点,复杂度是(O(nlog a_i))只能有(80pts)
代码
#include<bits/stdc++.h>
#define re register
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
#pragma GCC optimize(3)
#pragma GCC optimize("-fcse-skip-blocks")
inline int read() {
char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
const int mod=1e9+7;
const int M=7e7+5;
const int maxn=1e6+5;
int n,a[maxn],f[maxn],b[maxn],pos,mx,rt;
inline int qm(int x) {return x>=mod?x-mod:x;}
int l[M],r[M],gt[M],jt[M],sum[M],cnt;
bool ft[M];
inline void pushdown(int now,int lx,int ry) {
int mid=lx+ry>>1;
int lenl=mid-lx+1,lenr=ry-mid;
if(gt[now]!=-1) {
if(!l[now]) l[now]=++cnt;
if(!r[now]) r[now]=++cnt;
gt[l[now]]=gt[r[now]]=gt[now];
sum[l[now]]=1ll*gt[now]*lenl%mod;
sum[r[now]]=1ll*gt[now]*lenr%mod;
ft[l[now]]=ft[r[now]]=jt[l[now]]=jt[r[now]]=0;
gt[now]=-1;
}
if(ft[now]) {
if(!l[now]) l[now]=++cnt;
if(!r[now]) r[now]=++cnt;
ft[l[now]]^=1;ft[r[now]]^=1;
sum[l[now]]=qm(mod-sum[l[now]]);
sum[r[now]]=qm(mod-sum[r[now]]);
jt[l[now]]=qm(mod-jt[l[now]]);
jt[r[now]]=qm(mod-jt[r[now]]);
ft[now]=0;
}
if(jt[now]) {
if(!l[now]) l[now]=++cnt;
if(!r[now]) r[now]=++cnt;
sum[l[now]]=qm(sum[l[now]]+1ll*jt[now]*lenl%mod);
sum[r[now]]=qm(sum[r[now]]+1ll*jt[now]*lenr%mod);
jt[l[now]]=qm(jt[l[now]]+jt[now]);
jt[r[now]]=qm(jt[r[now]]+jt[now]);
jt[now]=0;
}
}
int gan(int now,int x,int y,int lx,int ry,int v) {
if(!now) now=++cnt,gt[now]=-1;
if(x<=lx&&y>=ry) {
sum[now]=1ll*v*(ry-lx+1)%mod;
gt[now]=v;ft[now]=0;jt[now]=0;
return now;
}
pushdown(now,lx,ry);
int mid=lx+ry>>1;
if(x<=mid) l[now]=gan(l[now],x,y,lx,mid,v);
if(y>mid) r[now]=gan(r[now],x,y,mid+1,ry,v);
sum[now]=qm(sum[l[now]]+sum[r[now]]);
return now;
}
int jia(int now,int x,int y,int lx,int ry,int v) {
if(!now) now=++cnt,gt[now]=-1;
if(x<=lx&&y>=ry) {
sum[now]=qm(sum[now]+1ll*v*(ry-lx+1)%mod);
jt[now]=qm(jt[now]+v);
return now;
}
pushdown(now,lx,ry);
int mid=lx+ry>>1;
if(x<=mid) l[now]=jia(l[now],x,y,lx,mid,v);
if(y>mid) r[now]=jia(r[now],x,y,mid+1,ry,v);
sum[now]=qm(sum[l[now]]+sum[r[now]]);
return now;
}
int qufan(int now,int x,int y,int lx,int ry) {
if(!now) now=++cnt,gt[now]=-1;
if(x<=lx&&y>=ry) {
sum[now]=qm(mod-sum[now]);
ft[now]^=1;jt[now]=qm(mod-jt[now]);
return now;
}
pushdown(now,lx,ry);
int mid=lx+ry>>1;
if(x<=mid) l[now]=qufan(l[now],x,y,lx,mid);
if(y>mid) r[now]=qufan(r[now],x,y,mid+1,ry);
sum[now]=qm(sum[l[now]]+sum[r[now]]);
return now;
}
int main() {
n=read();
for(re int i=1;i<=n;i++) a[i]=read();
pos=1;for(re int i=2;i<=n;i++) if(a[i]<a[pos]) pos=i;
for(re int i=1;i<=n;i++) {
b[i]=a[pos++];
if(pos>n) pos=1;
}
for(re int i=1;i<=n;i++) mx=max(mx,a[i]);
f[1]=b[1];rt=gan(rt,1,b[1],1,mx,1);
for(re int i=2;i<=n;++i) {
if(b[i-1]<b[i]) {
rt=qufan(rt,1,b[i-1],1,mx);
rt=jia(rt,1,b[i],1,mx,f[i-1]);
}
else {
rt=qufan(rt,1,b[i],1,mx);
rt=jia(rt,1,b[i],1,mx,f[i-1]);
if(b[i]+1<=b[i-1]) rt=gan(rt,b[i]+1,b[i-1],1,mx,0);
}
f[i]=sum[rt];
}
int ans=0;
for(re int i=n;i>=2;--i)
if((n-i+1)&1) ans=qm(ans+f[i]);
else ans=qm(ans-f[i]+mod);
printf("%d
",ans);
return 0;
}
正解也就是容斥+dp,但是容斥方法好像不太一样,就直接丢链跑了