如果(r-l)比较小,可以将所有满足条件的串扔进(AC)自动机然后在上面DP,从前往后确定字符串的每一位。
但是(l,r leq 10^{800})就十分不可行,所以需要优化这个算法。
考虑可能会有某一个节点的子节点连向的所有子节点构成一个满十叉树,意即当到达了这个节点之后可以随便往下走子节点,而且不论怎么走得到的结果都相同。比如说对于某一个九位数(x = overline {x_1x_2x_3x_4x_5x_6x_7x_8x_9}),当前的(l=123456789,r = 987654321),那么若(x_1 in [2,8]),无论(x_2)到(x_9)如何取值,(x)一定满足(l leq x leq r)。而当(x_1 = 1,x2 in [3,9])时都可以满足(l leq x leq r),无论(x_3)到(x_9)如何取值。
我们不妨把长度为(k)的“无论如何取值都能产生一个合法情况”的情况叫做产生一个(k)位通配符,比如说上面举的例子中(x_1in[2,8])时就会产生一个(8)位通配符,而(x_1 = 1,x2 in [3,9])则产生了一个(7)位通配符。特殊地,如果某个串刚好与(l)或(r)相等,我们认为产生了一个(0)位通配符(也就表示答案(+1))。
考虑如何会产生一个若干位的通配符。对于(geq l)的情况,当(forall j x_j = l_j)且(x_{j+1} > l_{j+1}),就会产生一个(|l| - j - 1)位的通配符,(leq r)的情况类似,而如果(|l| < |r|),只要任意选择第一位为(1)到(9)的数,就可以产生(|l| + 1)到(|r| - 1)位的通配符。
可以发现最后产生的串的答案数量就是这个串中出现的通配符数量的总和。那么我们只需要计算出(AC)自动机上到达每一个节点能够产生的通配符数量,然后就可以比较快速地DP了。
将(l,r)两个串以及所有能够出现通配符情况的串放在一起构建Trie图/AC自动机,构建方法类似数位DP,对于每一个(l)和(r)的前缀枚举能够产生通配符的下一个字符来计算(sum_{i,j})表示到达第(i)个节点能够获得的长度为(j)的通配符数量。注意通配符数量是可以通过(fail)指针传递的。具体细节请阅读下面代码中的insert
函数和build
函数
又设(dp_{i,j})表示长度为(i)、当前所在节点为(j)的最大通配符数量,转移:(dp_{i,j} + sumlimits_{k=0}^{N-i-1}sum_{ch_{i , c} , k} ightarrow dp_{i + 1 , ch_{i , c}})
后面的(sum)只需要枚举到(N-i-1)的原因是通配符的结尾不能超出字符串的结尾。
将(sum)前缀和,将复杂度做到(O((|l| + |r|)NA^2)),其中(A)为字符集大小。
输出方案倒推一边即可
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<ctime>
#include<cctype>
#include<algorithm>
#include<cstring>
#include<iomanip>
#include<queue>
#include<map>
#include<set>
#include<bitset>
#include<stack>
#include<vector>
#include<cmath>
#include<random>
//This code is written by Itst
using namespace std;
int N;
struct node{
int ch[10] , sum[2007] , fail;
}Trie[16017];
int cnt = 1 , L1 , L2;
char L[807] , R[807];
#define get(x , y) (!Trie[x].ch[y] ? Trie[x].ch[y] = ++cnt : Trie[x].ch[y])
void insert(){
L1 = strlen(L + 1) , L2 = strlen(R + 1);
int u = 1 , v = 1;
if(L1 == L2){
for(int j = 1 ; j <= L1 ; ++j)
if(u == v){
for(int k = L[j] - '0' + 1 ; k < R[j] - '0' ; ++k)
++Trie[get(u , k)].sum[L1 - j];
u = get(u , L[j] - '0');
v = get(v , R[j] - '0');
}
else{
for(int k = L[j] - '0' + 1 ; k <= 9 ; ++k)
++Trie[get(u , k)].sum[L1 - j];
u = get(u , L[j] - '0');
for(int k = j == 1 ; k < R[j] - '0' ; ++k)
++Trie[get(v , k)].sum[L2 - j];
v = get(v , R[j] - '0');
}
++Trie[u].sum[0];Trie[v].sum[0] += u != v;
}
else{
for(int j = 1 ; j <= L1 ; ++j){
for(int k = L[j] - '0' + 1 ; k <= 9 ; ++k)
++Trie[get(u , k)].sum[L1 - j];
u = get(u , L[j] - '0');
}
for(int j = 1 ; j <= L2 ; ++j){
for(int k = j == 1 ; k < R[j] - '0' ; ++k)
++Trie[get(v , k)].sum[L2 - j];
v = get(v , R[j] - '0');
}
for(int j = L1 + 1 ; j < L2 ; ++j)
for(int k = 1 ; k <= 9 ; ++k)
++Trie[get(1 , k)].sum[j - 1];
++Trie[u].sum[0];++Trie[v].sum[0];
}
}
void build(){
queue < int > q;
for(int i = 0 ; i < 10 ; ++i)
if(!Trie[1].ch[i])
Trie[1].ch[i] = 1;
else{
Trie[Trie[1].ch[i]].fail = 1;
q.push(Trie[1].ch[i]);
}
while(!q.empty()){
int t = q.front();
q.pop();
for(int j = 0 ; j < L2 ; ++j)
Trie[t].sum[j] += Trie[Trie[t].fail].sum[j];
for(int i = 0 ; i < 10 ; ++i)
if(!Trie[t].ch[i])
Trie[t].ch[i] = Trie[Trie[t].fail].ch[i];
else{
Trie[Trie[t].ch[i]].fail = Trie[Trie[t].fail].ch[i];
q.push(Trie[t].ch[i]);
}
}
for(int i = 1 ; i <= cnt ; ++i)
for(int j = 1 ; j < N ; ++j)
Trie[i].sum[j] += Trie[i].sum[j - 1];
}
void init(){
scanf("%s %s %d" , L + 1 , R + 1 , &N);
insert();
build();
}
int dp[2007][16017];
bool can[2007][16017];
inline int maxx(int a , int b){
return a > b ? a : b;
}
int main(){
init();
memset(dp , -0x3f , sizeof(dp));
dp[0][1] = 0;
for(int i = 0 ; i < N ; ++i)
for(int j = 1 ; j <= cnt ; ++j)
if(dp[i][j] >= 0)
for(int k = 0 ; k < 10 ; ++k)
dp[i + 1][Trie[j].ch[k]] = maxx(dp[i + 1][Trie[j].ch[k]] , dp[i][j] + Trie[Trie[j].ch[k]].sum[N - i - 1]);
int ans = 0;
for(int i = 1 ; i <= cnt ; ++i)
ans = maxx(ans , dp[N][i]);
cout << ans << endl;
for(int i = 1 ; i <= cnt ; ++i)
can[N][i] = (dp[N][i] == ans);
for(int i = N - 1 ; i >= 0 ; --i)
for(int j = 1 ; j <= cnt ; ++j)
if(dp[i][j] >= 0)
for(int k = 0 ; !can[i][j] && k < 10 ; ++k)
can[i][j] = can[i + 1][Trie[j].ch[k]] && (dp[i + 1][Trie[j].ch[k]] == dp[i][j] + Trie[Trie[j].ch[k]].sum[N - i - 1]);
int u = 1;
for(int i = 1 ; i <= N ; ++i)
for(int j = 0 ; j < 10 ; ++j)
if(can[i][Trie[u].ch[j]] && (dp[i][Trie[u].ch[j]] == dp[i - 1][u] + Trie[Trie[u].ch[j]].sum[N - i])){
putchar(j + '0');
u = get(u , j);
break;
}
return 0;
}