题目
解法
相交的回文子串是不好维护的,如果直接用数据结构将某段区间 \(+1\) 表示有回文子串,不仅会 \(\mathtt{T}\) 而且会算重。不妨计算不相交的回文子串对数,再用所有回文串对数减去即可。
设 \(r_i\) 是以 \(i\) 为结尾的回文子串个数,\(l_i\) 是以 \(j\) 为起始(\(j\ge i\))的回文子串的个数。回文子串直接用 \(\text{Manacher}\) 求解即可,具体统计 \(l_i,r_i\) 时由于对于 \(i\),右端点属于 \([i,i+p_i-1]\) 都是满足条件的,所以需要差分。
不相交的回文子串对数很容易被表示成 \(\sum_{i=0}^{n-2}r_i\cdot l_{i+1}\)。
代码
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
using namespace std;
const int N = 2e6 + 5, mod = 51123987, inv2 = 25561994;
long long ans;
int n, p[N << 1], s[N << 1], ori[N], l[N << 1], r[N << 1], lef[N], rig[N];
int read() {
int x = 0, f = 1; char s;
while((s = getchar()) > '9' || s < '0') {
if(s == '-') f = -1;
if(s == EOF) exit(0);
}
while(s <= '9' && s >= '0') {
x = (x << 1) + (x << 3) + (s ^ 48);
s = getchar();
}
return x * f;
}
int init() {
s[0] = -1; s[1] = -3;
int j = 2;
for(int i = 0; i < n; ++ i) {
s[j ++] = ori[i];
s[j ++] = -3;
}
s[j] = -2;
return j;
}
void manacher() {
n = init();
int ans = -1, R = 0, mid;
for(int i = 1; i < n; ++ i) {
if(i < R) p[i] = min(p[(mid << 1) - i], R - i);
else p[i] = 1;
while(s[i - p[i]] == s[i + p[i]]) ++ p[i];
if(R < i + p[i]) {
R = i + p[i];
mid = i;
}
}
}
long long fix(const long long x) {return (x % mod + mod) % mod;}
int main() {
char ch[N];
int tmp;
n = read(); scanf("%s", ch);
for(int i = 0; i < n; ++ i) ori[i] = ch[i] - 'a';
tmp = n;
manacher();
for(int i = 1; i < n; ++ i) {
++ l[i - p[i] + 1]; -- l[i + 1];
++ r[i]; -- r[i + p[i]];
}
for(int i = 1; i < n; ++ i) {
(ans += p[i] >> 1) %= mod; // count the number of palindrome string
(l[i] += l[i - 1]) %= mod; (r[i] += r[i - 1]) %= mod;
if(s[i] != -3) lef[i - 1 >> 1] = l[i], rig[i - 1 >> 1] = r[i];
}
for(int i = tmp - 2; i >= 0; -- i) (lef[i] += lef[i + 1]) %= mod;
ans = ans * fix(ans - 1) % mod * inv2 % mod;
for(int i = 0; i < tmp; ++ i)
ans = fix(ans - 1ll * rig[i] * lef[i + 1] % mod);
printf("%lld\n", ans);
return 0;
}