题意

思路
官方题解

#include <bits/stdc++.h>
#define LL long long
#define ULL unsigned long long
#define UI unsigned int
#define mem(i, j) memset(i, j, sizeof(i))
#define rep(i, j, k) for(int i = j; i <= k; i++)
#define dep(i, j, k) for(int i = k; i >= j; i--)
#define pb push_back
#define make make_pair
#define INF 0x3f3f3f3f
#define inf LLONG_MAX
#define PI acos(-1)
#define fir first
#define sec second
#define lb(x) ((x) & (-(x)))
#define dbg(x) cout<<#x<<" = "<<x<<endl;
using namespace std;
// Capacity of every global array below (about 1e6 entries plus slack).
// solve() fills the arrays starting at index 1; index 0 is the empty prefix.
const int N = 1e6 + 5;
// Modulus applied to pw_ten, pre and the printed answer.
const LL mod = 998244353;
// Multiplier of the polynomial rolling hash; hash arithmetic is done in ULL
// and simply wraps on unsigned overflow.
const ULL base = 101;
// pw_base[i] = base^i, hs[i] = rolling hash of a[1..i].
ULL pw_base[N], hs[N];
// pw_ten[i] = 10^i mod `mod`, pre[i] = decimal value of a[1..i] mod `mod`.
LL pw_ten[N], pre[N];
// Input string; solve() reads it into a + 1, so characters start at a[1].
char a[N];
// id[i] starts as i (a prefix length) and is reordered by the sort in solve().
int id[N];
/// Returns the rolling hash of a[l..r] (1-indexed, inclusive).
/// The prefix hash hs[l - 1] is shifted by the substring width and removed
/// from hs[r]; an empty range (r == l - 1) therefore yields 0.
ULL get(int l, int r) {
    const int width = r - l + 1;
    const ULL shifted_prefix = hs[l - 1] * pw_base[width];
    return hs[r] - shifted_prefix;
}
/// Returns the hash of the first `pos` characters of the string x + A, where
/// x = a[st+1..st+len] and A is the input string starting at a[1].
ULL get1(int st, int len, int pos) {
    if (pos > len) {
        // The requested prefix covers all of x plus the first `tail`
        // characters of A: shift x's hash left by `tail` and append A's part.
        const int tail = pos - len;
        const ULL x_hash = get(st + 1, st + len);
        const ULL a_hash = get(1, tail);
        return x_hash * pw_base[tail] + a_hash;
    }
    // The requested prefix lies entirely inside x.
    return get(st + 1, st + pos);
}
/// Sort predicate over prefix lengths: returns true when prefix a[1..x] must
/// be placed before prefix a[1..y] in the concatenation built by solve().
///
/// With mi = min(x, y), ma = max(x, y) and dis = ma - mi, both concatenations
/// start with a[1..mi]. After dropping that common part, the order
/// "shorter first" reads a[1..ma], while "longer first" reads
/// a[mi+1..ma] + a[1..mi]. The first differing position of these two strings
/// is located by binary search on prefix hashes (equal hashes are treated as
/// equal strings), and the order yielding the larger character there wins.
///
/// Ties (no differing position) are broken by putting the longer prefix
/// first. The function returns false for x == y, as std::sort requires of its
/// comparator; the previous `x >= y` result violated that requirement.
bool cmp(int x, int y) {
    if (x == y) return false;  // irreflexivity: an element never precedes itself

    const int mi = std::min(x, y);
    const int ma = std::max(x, y);
    const int dis = ma - mi;

    // Smallest length at which the two candidate strings differ, 0 if none.
    // "Prefixes differ" is monotone in the length, so binary search applies.
    int l = 1, r = ma, ans = 0;
    while (l <= r) {
        const int mid = (l + r) >> 1;
        if (get1(mi, dis, mid) != hs[mid]) {
            ans = mid;
            r = mid - 1;
        } else {
            l = mid + 1;
        }
    }

    // Both orders give the same string: fall back to a fixed tie-break
    // instead of comparing characters at a non-existent mismatch position.
    if (ans == 0) return x > y;

    const char ax = a[ans];  // character of a[1..ma] ("shorter first")
    // Character of a[mi+1..ma] + a[1..mi] ("longer first") at the same spot.
    const char bx = (ans <= dis) ? a[mi + ans] : a[ans - dis];

    if (ax > bx) return x < y;  // shorter prefix goes first
    return x > y;               // longer prefix goes first
}
void solve() {
scanf("%s", a + 1);
int n = strlen(a + 1);
pw_ten[0] = 1; pw_base[0] = 1;
rep(i, 1, n) {
id[i] = i;
pw_ten[i] = pw_ten[i - 1] * 10 % mod;
pre[i] = (pre[i - 1] * 10 % mod + (a[i] - '0')) % mod;
pw_base[i] = pw_base[i - 1] * base;
hs[i] = hs[i - 1] * base + a[i];
}
sort(id + 1, id + 1 + n, cmp);
LL ans = 0LL;
rep(i, 1, n) {
ans = (ans * pw_ten[id[i]] % mod + pre[id[i]]) % mod;
}
printf("%lld
", ans);
}
/// Program entry point: the input holds a single test case, so solve() runs
/// exactly once.
int main() {
    // Disabled multi-test driver, kept for reference:
    //   int _; scanf("%d", &_);
    //   while(_--) solve();
    solve();
    // Falling off the end of main is equivalent to `return 0;`.
}