题面
题解
先说下思路, 要拼接成回文, j 拼接到 i 的后面
- len(i) >= len(j) 则 j的反串 完全匹配 i, 且 i 剩下的 子串是回文
- len(j) > len(i) 则 i 完全匹配 j的反串, 且 j反串 剩下的 子串是回文
把每个串的反串扔到trie中, 每个节点存 当前在这里结尾的字串的数量, 和当前节点下面有多少个回文串
求匹配 直接用kmp
然而你过不了, 这题异常卡常, 只好把 每个串的 extend (KMP扩展数组) 存下来,
也不行! 还要 string 变 char
代码
const int N = 2e6 + 5;
struct node {
static const int N = 2e6 + 5, M = 26;
int tr[N][M], cnt[N], ended[N], tot;
void insert(char* s, int* ext, int len) {
int p = 0;
for (int i = 0; s[i]; ++i) {
int ch = s[i] - 'a';
if (!tr[p][ch]) tr[p][ch] = ++tot;
p = tr[p][ch];
if (i + 1 < len && ext[i + 2] + i + 1 == len) ++cnt[p];
}
++ended[p];
}
int query(char* s, int* ext, int len) {
int p = 0; ll ans = 0;
for (int i = 0; s[i]; ++i) {
int ch = s[i] - 'a';
if (!tr[p][ch]) return ans;
p = tr[p][ch];
if (ext[i + 2] + i + 1 == len) ans += ended[p];
}
return cnt[p] + ans;
}
} tr;
int n, m, _, k;
int f[N << 1], exta[N << 1], a[N], extb[N << 1];
char s[N << 1], t[N];
void kmp(char* t, int lent, int* f) {
int j = 0, k = 2;
while (j + 2 <= lent && t[j] == t[j + 1]) ++j;
f[2] = j; f[1] = lent;
rep(i, 3, lent) {
int p = k + f[k] - 1;
if (i + f[i - k + 1] - 1 < p) f[i] = f[i - k + 1];
else {
j = max(0, p - i + 1);
while (j + i <= lent && t[j] == t[i + j - 1]) ++j;
f[i] = j; k = i;
}
}
}
void ex_kmp(char* s, char* t, int lens, int lent, int* f, int* extend) {
int j = 0, k = 1;
while (j + 1 <= min(lens, lent) && s[j] == t[j]) ++j;
extend[1] = j;
rep(i, 2, lens) {
int p = k + extend[k] - 1;
if (i + f[i - k + 1] - 1 < p) extend[i] = f[i - k + 1];
else {
j = max(0, p - i + 1);
while (j + i <= lens && j + 1 <= lent && t[j] == s[i + j - 1]) ++j;
extend[i] = j; k = i;
}
}
}
int main() {
IOS; cin >> n;
rep(i, 0, n - 1) {
cin >> m >> s + a[i] + 1; a[i + 1] = a[i] + m + 1;
rep (j, a[i], a[i + 1] - 1) t[m + 1 - j + a[i]] = s[j];
kmp(t + 1, m, f + a[i]);
ex_kmp(s + a[i] + 1, t + 1, m, m, f + a[i], exta + a[i]);
kmp(s + a[i] + 1, m, f + a[i]);
ex_kmp(t + 1, s + a[i] + 1, m, m, f + a[i], extb + a[i]);
tr.insert(t + 1, extb + a[i], m);
}
ll ans = 0;
rep(i, 0, n - 1) ans += tr.query(s + a[i] + 1, exta + a[i], a[i + 1] - a[i] - 1);
cout << ans;
return 0;
}