Description
给定一个⻓为 n 的字符串 s , 问有多少个⻓为 m 的字符串 t 满足:
将 t 无限重复后,可以从中截出一个⻓度为 n 且字典序比 s 小的串。
m ≤ 2000 n ≤ 2000
Solution
正难则反,补集转换,用 (26^m) 减去“无法从中截出字典序比 s 小的串”的方案数。
方便表述,称字符串t具有特征 (A) 当且仅当无法从无限重复的t中截出一段长度为m且字典序比s小的字段即A为任意无限重复的t中长度为m的字典序都比s大。
考虑构造一个有限状态自动机能接受所有满足特征A的串,然后在上面计数,那么我们要统计对于每个节点开头走m条边后回到它自己的方案数(t串是无限长的)。
由于需要满足特征A,所以一个点的出边只有最大的边是有用的,因为满足A的字符串一定不会走更小的边,(要么比s大,要么目前和s一样,比s大对应的是已经接受了一个满足A的串,直接跳到根,和s一样说明要继续走下去)。
于是这就是一个只保留最大转移边的kmp自动机。
并且一个节点只有一条出边,还有许多边指向根,后者之间本质是一样的我们只要记个数即可(代码实现中是edge[i],表示i点指向根的边数)。
现在考虑如何在上面dp,不难发现这个图很特殊是一个rho,图上的路径只有两种:
- 在环上走m步回到自己,只有当环的大小为m的约数时存在。
- 从自己走若步(比如j步)到根,再从根走m-j步回到自己。
前者直接找环算,后者设 (f[i][u]) 表示从根走i步到u的方案数, (g[i][u]) 为从u走i步到根的方案数,dp出来后枚举j即可。
[f[i + 1][v] leftarrow f[i][u] \
f[i + 1][0] leftarrow f[i][u] imes edge[u]\
g[i + 1][u] leftarrow g[i][v]
]
Code
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <fstream>
typedef long long LL;
typedef unsigned long long uLL;
#define SZ(x) ((int)x.size())
#define ALL(x) (x).begin(), (x).end()
#define MP(x, y) std::make_pair(x, y)
#define DE(x) cerr << x << endl;
#define debug(...) fprintf(stderr, __VA_ARGS__)
#define GO cerr << "GO" << endl;
using namespace std;
inline void proc_status()
{
ifstream t("/proc/self/status");
cerr << string(istreambuf_iterator<char>(t), istreambuf_iterator<char>()) << endl;
}
inline int read()
{
register int x = 0; register int f = 1; register char c;
while (!isdigit(c = getchar())) if (c == '-') f = -1;
while (x = (x << 1) + (x << 3) + (c xor 48), isdigit(c = getchar()));
return x * f;
}
template<class T> inline void write(T x)
{
static char stk[30]; static int top = 0;
if (x < 0) { x = -x, putchar('-'); }
while (stk[++top] = x % 10 xor 48, x /= 10, x);
while (putchar(stk[top--]), top);
}
template<typename T> inline bool chkmin(T &a, T b) { return a > b ? a = b, 1 : 0; }
template<typename T> inline bool chkmax(T &a, T b) { return a < b ? a = b, 1 : 0; }
const int maxN = 2e3;
const int mod = 998244353;
namespace math
{
void pls(int &x, int y)
{
x += y;
if (x >= mod) x -= mod;
if (x < 0) x += mod;
}
LL qpow(LL a, LL b)
{
LL ans(1);
while (b)
{
if (b & 1)
ans = ans * a % mod;
a = a * a % mod;
b >>= 1;
}
return ans;
}
}
using math::pls;
using math::qpow;
int n, m; //n字符串长度,m走m步
char str[maxN + 2];
int fail[maxN + 2], ver[maxN + 2], edge[maxN + 2];
void insert()
{
fail[1] = 0;
for (int i = 2, j = 0; i <= n; ++i)
{
while (j and str[j + 1] != str[i]) j = fail[j];
j += str[j + 1] == str[i];
fail[i] = j;
}
}
void build()
{
for (int i = 0; i <= n; ++i)
{
for (int j = 25; j >= 0; --j)
{
int p = i;
if (p == n) p = fail[p];
while (p and str[p + 1] != j + 'a') p = fail[p];
p += (str[p + 1] == (j + 'a'));
if (p)
{
ver[i] = p;
edge[i] = 25 - j;
break;
}
}
}
}
int size;
int f[maxN + 2][maxN + 2], g[maxN + 2][maxN + 2]; // f[i][u] : root -> u cost i ; g[i][u] : u -> root cost i
void DP()
{
f[0][0] = 1;
for (int i = 0; i < m; ++i)
for (int j = 0; j <= n; ++j)
{
pls(f[i + 1][ver[j]], f[i][j]);
pls(f[i + 1][0], 1ll * f[i][j] * edge[j] % mod);
}
for (int i = 0; i <= n; ++i)
g[1][i] = edge[i];
for (int i = 2; i <= m; ++i)
for (int j = 0; j <= n; ++j)
g[i][j] = g[i - 1][ver[j]];
}
int key;
bool vis[maxN + 2];
bool dfs(int u)
{
if (!u) return 0;
if (vis[u]) { key = u; return 1; }
vis[u] = 1;
if (dfs(ver[u])) { size++; return key != u; }
return 0;
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("xhc.in", "r", stdin);
freopen("xhc.out", "w", stdout);
#endif
scanf("%d %s", &m, str + 1);
n = strlen(str + 1);
insert();
build();
DP();
int ans = 0;
dfs(1);
if (m % size == 0)
ans = size;
for (int i = 0; i <= n; ++i)
{
int sum = 0;
for (int j = 0; j <= m; ++j)
pls(sum, 1ll * f[j][i] * g[m - j][i] % mod);
pls(ans, sum);
}
cout << ((qpow(26, m) - ans) % mod + mod) % mod << endl;
return 0;
}