这是一个模板,如果明白插头 $dp$ 的原理,大力分类讨论就完事了
注意一些细节,比如统计答案时不一定是在 $n,m$ ,因为 $n,m$ 可能不能放
自认为自己的代码比较好看...
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<cmath> using namespace std; typedef long long ll; inline int read() { int x=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-') f=-1; ch=getchar(); } while(ch>='0'&&ch<='9') { x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); } return x*f; } const int N=2e5,M=2e6; int n,m,a[17][17],bit[27],cur,pre,lstx,lsty; ll f[2][M],Ans; int fir[N+7],from[M+7],to[2][M+7],cntt[2]; inline void insert(int sta,ll val) { int p=sta%N; for(int i=fir[p];i;i=from[i]) if(to[cur][i]==sta) { f[cur][i]+=val; return; } from[++cntt[cur]]=fir[p]; fir[p]=cntt[cur]; to[cur][cntt[cur]]=sta; f[cur][cntt[cur]]=val; } void DP() { f[cur][1]=cntt[cur]=1; int now,right,down; ll val; for(int i=1;i<=n;i++) { for(int j=1;j<=cntt[cur];j++) to[cur][j]<<=2; for(int j=1;j<=m;j++) { pre=cur; cur^=1; cntt[cur]=0; memset(fir,0,sizeof(fir)); for(int k=1;k<=cntt[pre];k++) { now=to[pre][k]; val=f[pre][k]; right=(now>>bit[j-1])%4; down=(now>>bit[j])%4; if(a[i][j]) { if(!right&&!down) insert(now,val); continue; } if(!right&&!down) if(j!=m) insert(now+(1<<bit[j-1])+((1<<bit[j])<<1),val); if(right&&!down) { insert(now,val); if(j!=m) insert(now-right*(1<<bit[j-1])+right*(1<<bit[j]),val); } if(!right&&down) { insert(now-down*(1<<bit[j])+down*(1<<bit[j-1]),val); if(j!=m) insert(now,val); } if(right==1&&down==1) { int cnt=1; for(int l=j+1;l<=m;l++) { if((now>>bit[l])%4==1) cnt++; if((now>>bit[l])%4==2) cnt--; if(!cnt) { insert(now-(1<<bit[j-1])-(1<<bit[j])-(1<<bit[l]),val); break; } } } if(right==2&&down==2) { int cnt=1; for(int l=j-2;l>=0;l--) { if((now>>bit[l])%4==1) cnt--; if((now>>bit[l])%4==2) cnt++; if(!cnt) { insert(now-((1<<bit[j-1])<<1)-((1<<bit[j])<<1)+(1<<bit[l]),val); break; } } } if(right==2&&down==1) insert(now-((1<<bit[j-1])<<1)-(1<<bit[j]),val); if(right==1&&down==2) if(now==(1<<bit[j-1])+((1<<bit[j])<<1)) if(i==lstx&&j==lsty) Ans+=val; } } } } int main() { for(int i=1;i<=26;i++) bit[i]=(i<<1); n=read(),m=read(); char s[17]; for(int i=1;i<=n;i++) { scanf("%s",s+1); for(int j=1;j<=m;j++) a[i][j]=(s[j]=='*'); } for(lstx=n;lstx;lstx--) { bool flag=0; for(lsty=m;lsty;lsty--) if(!a[lstx][lsty]) { flag=1; break; } if(flag) break; } DP(); printf("%lld ",Ans); return 0; }