sol
先预处理从每个点出发向上/下/左/右能延伸多长。
考虑怎么计算答案。我们只要枚举中轴线,再枚举上方的十字交点,枚举下方的十字交点,然后算答案即可。
考虑一个左右宽的最小值为(L)的水平线段对下方的影响。对于下方宽度(in[2,L])的线段,相当于加上一个等差数列,而对于宽度(>L)的线段,相当于加上一个定值(L-1)。
所以我们现在要做的就是:动态支持区间加等差数列,区间求和。
用树状数组维护的话就需要维护二阶差分。设需要维护的数列是(a_i),他的一阶差分是(b_i),树状数组维护的二阶差分是(c_i),有:
[sum_{i=1}^xa_i=sum_{i=1}^{x}sum_{j=1}^ib_j\=sum_{j=1}^x(x-j+1)b_j\=sum_{j=1}^x(x-j+1)sum_{k=1}^jc_k\=sum_{k=1}^xc_ksum_{j=k}^x(x-j+1)\=sum_{k=1}^xc_kfrac 12(x-k+2)(x-k+1)\=frac 12[sum_{k=1}^xc_kk^2-(2x+3)sum_{k=1}^xc_kk+(x^2-3x+2)sum_{k=1}^xc_k]
]
所以开三个树状数组维护(sum_kc_k,sum_kc_kk,sum_kc_kk^2)的前缀和即可。
复杂度(O(RClog n)),由于暴力清空了树状数组所以复杂度貌似还要带个(O(Cn))。
当然你要是精细一点的清空是可以做到把这个复杂度去掉的,只是写起来就麻烦一点。
code
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
int gi(){
int x=0,w=1;char ch=getchar();
while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
if (ch=='-') w=0,ch=getchar();
while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return w?x:-x;
}
const int N = 2e6+5;
const int mod = 1e9+9;
const int inv2 = 5e8+5;
int n,m,k,vis[N],u[N],d[N],l[N],r[N],h[N],c1[N],c2[N],c3[N],ans;
inline int p(int x,int y){if(!x||!y||x>n||y>m)return 0;return (x-1)*m+y;}
inline void add(int &x,int y){x+=y;if(x>=mod)x-=mod;}
void init(){for (int i=1;i<=m;++i) c1[i]=c2[i]=c3[i]=0;}
void modify(int x,int v){
for (int i=x;i<=m;i+=i&-i){
add(c1[i],v);
add(c2[i],1ll*x*v%mod);
add(c3[i],1ll*x*x%mod*v%mod);
}
}
int query(int x){
int s1=0,s2=0,s3=0,res=0;
for (int i=x;i;i-=i&-i)
add(s1,c1[i]),add(s2,c2[i]),add(s3,c3[i]);
add(res,1ll*(1ll*x*x+3*x+2)%mod*s1%mod);
add(res,mod-1ll*(x+x+3)*s2%mod);add(res,s3);
res=1ll*res*inv2%mod;return res;
}
int main(){
n=gi();m=gi();k=gi();
for (int i=1,x,y;i<=k;++i) x=gi(),y=gi(),vis[p(x,y)]=1;
for (int i=1;i<=n;++i)
for (int j=1;j<=m;++j)
if (!vis[p(i,j)]) u[p(i,j)]=u[p(i-1,j)]+1,l[p(i,j)]=l[p(i,j-1)]+1;
for (int i=n;i;--i)
for (int j=m;j;--j)
if (!vis[p(i,j)]) d[p(i,j)]=d[p(i+1,j)]+1,r[p(i,j)]=r[p(i,j+1)]+1;
for (int i=1,id;i<=n;++i)
for (int j=1;j<=m;++j)
if (!vis[p(i,j)]){
id=p(i,j);
h[id]=min(l[id],r[id])-1;
--d[id];--u[id];
}
for (int j=2;j<m;++j,init())
for (int i=3;i<n;++i){
int id=p(i,j);
if (vis[id]) {init();continue;}
if (h[id]) add(ans,1ll*query(h[id]-1)*d[id]%mod);
modify(1,u[id-m]);modify(h[id-m]+1,mod-u[id-m]);
}
printf("%d
",ans);return 0;
}