Sample Input
3
1 a
2 ab
2 ba
Sample Output
5
Hint
The 5 palindromes are:
aa aba aba abba baab
题意
给定一系列字符串,询问它们能组成多少种回文(两两字符相连接,且满足自身与自身组合)
思路
对于字符串s和t,有三种情况它们可以组成一个回文;
1.|s| > |t|, t的反串作为s的前缀,s的后缀为一个回文
abcc
ba
2.|t| > |s|, s作为t反串的前缀, t的前缀为一个回文(保证翻转后该段不变)
ab
ccba
3.s == t, 且s, t本身就是回文
aa
aa
我们可以用manacher计算出每一个串在哪个节点之后均为回文,并将其插入到trie中在
对应的后缀为回文的节点上进行标记(后缀为回文就加一),最后当我们用一个字符串来
进行匹配时只需要加上这个字符串末尾位置的答案即可;
但是上面的操作我们只计算了第一种情况,我们可以在将所有串进行翻转后在插入到trie
中在进行一次上述操作,但是这一次我们会将第三中情况多算一遍,所以统计时需要减去这部分的
值
样例:
aa
ab
ccab
abcc
代码
#pragma GCC optimize(2)
// #include<unordered_map>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<string>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#include<set>
#define Buff ios::sync_with_stdio(false)
#define rush() int Case = 0; int T; cin >> T; while(T--)
#define rep(i, a, b) for(int i = a; i <= b; i ++)
#define per(i, a, b) for(int i = a; i >= b; i --)
#define reps(i, a, b) for(int i = a; b; i ++)
#define clc(a, b) memset(a, b, sizeof(a))
#define Buff ios::sync_with_stdio(false)
#define readl(a) scanf("%lld", &a)
#define readd(a) scanf("%lf", &a)
#define readc(a) scanf("%c", &a)
#define reads(a) scanf("%s", a)
#define read(a) scanf("%d", &a)
#define lowbit(n) (n&(-n))
#define pb push_back
#define sqr(x) x*x
#define lson rt<<1
#define rson rt<<1|1
#define ls lson, l, mid
#define rs rson, mid+1, r
#define y second
#define x first
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int>PII;
const int mod = 1e9+7;
const double eps = 1e-6;
const int N = 2e6+7;
ll res;
char s[N], t[N<<1];
int f[N<<1], tr[N][26], val[N];
int start[N], w[N], idx, cnt[N], len[N];
void insert(char* str, int len)
{
int p = 0;
// for(int i = 1; i <= len; i ++)
// cout << str[i];puts("");
for(int i = 1; i <= len; i ++)
{
int t = str[i] - 'a';
if(!tr[p][t]) tr[p][t] = ++idx;
p = tr[p][t];
val[p] += w[i];
// printf("val[%d]: %d
", p, val[p]);
}
cnt[p] ++;
}
void manacher(char* str, int len)
{
int pos = 0, maxr = 0;
int n = 0;
t[n] = '%';
for(int i = 1; i <= len; i ++)
{
t[++n] = '#';
t[++n] = str[i];
w[i] = 0;
}
t[++n] = '#';
t[n+1] = 0;
// cout << t+1 <<endl;
for(int i = 1; i <= n; i ++)
{
f[i] = i < maxr ? min(f[pos*2-i], maxr-i):1;
while(t[i-f[i]] == t[i+f[i]]) f[i] ++;
if(maxr < i+f[i]) pos = i, maxr = i+f[i];
if(i+f[i] > n)
w[(i-f[i])/2] = 1;
}
insert(str, len);
}
int find(char* s, bool op)
{
int p = 0;
reps(i, 1, s[i])
{
int t = s[i] - 'a';
if(!tr[p][t]) return 0;
p = tr[p][t];
}
return op ? cnt[p] : val[p];
}
void search(int st, int en,bool op)
{
int n = 0;
per(i, en, st) t[++n] = s[i];
t[n+1] = 0;
// cout << t+1 <<endl;
// cout << "flag0" << endl;
// cout << find(t, 0) <<endl;
res += find(t, 0);
if(op)
{
res -= find(t, 1);
// cout << "flag1: " << endl;
// cout << find(t, 1) <<endl;
}
}
void init()
{
idx = 0;
clc(tr, 0);
clc(cnt, 0);
clc(val, 0);
}
void rever(int st, int en)
{
while(st <= en)
{
swap(s[st], s[en]);
st ++;
en --;
}
}
int main()
{
Buff;
int n, l = 1;
cin >> n;
rep(i, 1, n)
{
cin >> len[i] >> (s+l);
start[i] = l;
manacher(s+l-1, len[i]);
l += len[i];
}
rep(i, 1, n)
search(start[i], start[i]+len[i]-1, 0);
init();
rep(i, 1, n)
{
rever(start[i], start[i]+len[i]-1);
manacher(s+start[i]-1, len[i]);
}
rep(i, 1, n)
search(start[i], start[i]+len[i]-1, 1);
cout << res <<endl;
return 0;
}