几个坑点(我都犯过)
(1. work1,work2,dfs1,dfs2)函数别写串了,这个还是比较好调出来的。
(2.dp)数组每次都初始化为(-1)(这个我没范)
(3.)处理前导零的时候当前位置和上一位都被代入了,但一开始上上位需要赋一个(!in[0,9])的数,我赋成(-1)忘记特判导致(dp)数组越界返回了错误答案,所以尽量赋成(10)吧
更多细节见代码
(Code)
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#define ll long long
#define re register
#define maxn 1010
using namespace std;
const ll mod=1e9+7;
inline int read()
{
int x=0,f=1; char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
char AA[maxn],BB[maxn];
int A[maxn],B[maxn],lena,lenb;
ll dp[1010][10][2][2][11],ans1,ans2,ans;
ll dfs1(int p,int a,int b,int c,int d)
{
//a表示上一位,b表示是否小于边界,c表示 是否构成回文串
if(p<=0) return c;
if(~dp[p][a][b][c][d]) return dp[p][a][b][c][d]%mod;
int lim=b?9:A[p];
ll res=0;
for(int i=0;i<=lim;++i)
res=(res+dfs1(p-1,i,(i<lim)||b,(i==a)||c||(i==d),a))%mod;
return dp[p][a][b][c][d]=res;
}
ll dfs2(int p,int a,int b,int c,int d)
{
//a表示上一位,b表示是否小于边界,c表示 是否构成回文串,d表示上上位
if(p<=0) return c;//如果d=-1,会访问-1下标,容易混乱越界返回错误值
if(~dp[p][a][b][c][d]) return dp[p][a][b][c][d]%mod;
int lim=b?9:B[p];
ll res=0;
for(int i=0;i<=lim;++i)
res=(res+dfs2(p-1,i,(i<lim)||b,(i==a)||c||(i==d),a))%mod;
return dp[p][a][b][c][d]=res;//回溯别忘记赋值
}
void work1()//处理较小的数
{
memset(dp,-1,sizeof(dp));//每次都重新赋值-1
for(re int i=1;i<=lena-1;++i)
for(re int j=1;j<=9;++j)//处理位数小的情况,这些位置从[1,9]取
ans1=(ans1+dfs1(i-1,j,1,0,10))%mod;
for(re int i=1;i<=A[lena];++i)
ans1=(ans1+dfs1(lena-1,i,i<A[lena],0,10))%mod;
}
void work2()//处理较大的数,注意这些函数别写混了
{
memset(dp,-1,sizeof(dp));
for(re int i=1;i<=lenb-1;++i)
for(re int j=1;j<=9;++j)
ans2=(ans2+dfs2(i-1,j,1,0,10))%mod;
for(re int i=1;i<=B[lenb];++i)
ans2=(ans2+dfs2(lenb-1,i,i<B[lenb],0,10))%mod;
}
int main()
{
cin>>AA+1>>BB+1;
lena=strlen(AA+1),lenb=strlen(BB+1);
for(re int i=1;i<=lena;++i) A[i]=AA[lena-i+1]-'0';
for(re int i=1;i<=lenb;++i) B[i]=BB[lenb-i+1]-'0';//反向存,高位下标大
A[1]--;//边界需要高精度减一下
for(re int i=1;i<=lena;++i)
{
if(A[i]<0) A[i]+=10,A[i+1]--;
else break;
}
work2();
work1();
ans=((ans2-ans1)%mod+mod)%mod;//防止爆负数,(x%mod+mod)%mod
printf("%lld
",ans);
return 0;
}