正题
题目链接:https://www.ybtoj.com.cn/problem/526
题目大意
一个\(n\times m\)的网格上有字母,你每次可以沿平行坐标轴对折网格,要求对折的对应位置字母相同。
询问有多少个可能对折出来的子矩阵。
\(1\leq n\times m\leq 10^6\)
解题思路
首先行和列是独立的,行的对折不会和列的对折有任何关联,所以可以分开考虑行和列可以对折出的区间。
然后设每一行分开对每个轴求出一个最大对折距离(这个用二分+\(hash\)或者马拉车就可以求出来了),然后同位置的所有行取最小值就好了。
之后对于每个轴的位置就有一个可以转移过来的区间,而且左右的对折如果过头了不会影响答案(可以自己画个图,因为回文串的性质,那么两边一定可以先对折出一个更小不会冲突的区间)
维护一个前缀和就好了(考场上犯病写了个树状数组)
时间复杂度\(O(n\log n)\)
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
#define ull unsigned long long
#define lowbit(x) (x&-x)
using namespace std;
const ll N=1e6+10;
const ull g=131;
ll n,m,t[N],ac[N],cr[N],dp[N],lim;
ull h[N],f[N],pw[N];
char c[N],*s[N];
void Change(ll x,ll val){
while(x<=lim){
t[x]+=val;
x+=lowbit(x);
}
return;
}
ll Ask(ll x){
ll ans=0;
while(x){
ans+=t[x];
x-=lowbit(x);
}
return ans;
}
ll Query(ll l,ll r)
{return Ask(r)-Ask(l-1);}
ull geth(ll l,ll r)
{return h[r]-h[l-1]*pw[r-l+1];}
ull getf(ll l,ll r)
{return f[l]-f[r+1]*pw[r-l+1];}
signed main()
{
freopen("paper.in","r",stdin);
freopen("paper.out","w",stdout);
scanf("%lld%lld",&n,&m);pw[0]=1;
for(ll i=1;i<=max(n,m);i++)
pw[i]=pw[i-1]*g;
memset(ac,0x3f,sizeof(ac));
memset(cr,0x3f,sizeof(cr));
s[1]=c-1;
for(ll p=1;p<=n;p++){
scanf("%s",s[p]+1);
for(ll i=1;i<=m;i++)
h[i]=h[i-1]*g+s[p][i]-'a';
for(ll i=m;i>=1;i--)
f[i]=f[i+1]*g+s[p][i]-'a';
for(ll i=2;i<=m;i++){
ll l=0,r=min(i-2,m-i);
while(l<=r){
ll mid=(l+r)>>1;
if(geth(i-mid-1,i-1)==getf(i,i+mid))l=mid+1;
else r=mid-1;
}
ac[i]=min(ac[i],r);
}
s[p+1]=s[p]+m;
}
f[n+1]=0;
for(ll p=1;p<=m;p++){
for(ll i=1;i<=n;i++)
h[i]=h[i-1]*g+s[i][p]-'a';
for(ll i=n;i>=1;i--)
f[i]=f[i+1]*g+s[i][p]-'a';
for(ll i=2;i<=n;i++){
ll l=0,r=min(i-2,n-i);
while(l<=r){
ll mid=(l+r)>>1;
if(geth(i-mid-1,i-1)==getf(i,i+mid))l=mid+1;
else r=mid-1;
}
cr[i]=min(cr[i],r);
}
}
lim=m;Change(1,1);dp[1]=1;
for(ll i=2;i<=m;i++){
bool tmp=(Query(i-ac[i]-1,i-1)!=0);
dp[i]=dp[i-1]+tmp;
if(tmp)Change(i,1);
}
memset(t,0,sizeof(t));
Change(m,1);ll sum=dp[m];
for(ll i=m-1;i>=1;i--){
bool tmp=(Query(i+1,i+ac[i+1]+1)!=0);
if(tmp)sum+=dp[i],Change(i,1);
}
memset(t,0,sizeof(t));
lim=n;Change(1,1);
for(ll i=2;i<=n;i++){
bool tmp=(Query(i-cr[i]-1,i-1)!=0);
dp[i]=dp[i-1]+tmp;
if(tmp)Change(i,1);
}
memset(t,0,sizeof(t));
Change(n,1);ll ans=dp[n];
for(ll i=n-1;i>=1;i--){
bool tmp=(Query(i+1,i+cr[i+1]+1)!=0);
if(tmp)ans+=dp[i],Change(i,1);
}
printf("%lld\n",ans*sum);
return 0;
}