POJ 3376
trie + Manacher
题意
http://poj.org/problem?id=3376
给 $n$ 个串,两两按顺序连接(包括一个串与自身连接),一共 $n^2$ 种,求其中有多少是回文。所有字符串的长度之和小于 $2\times 10^6$。
Solution
考虑第 $i$ 个串 $t$ 时,统计有多少个串 $s$ 使得 $s+t$ 是回文;把所有 $t$ 对应的数量累加起来就是最终的回文数。
首先考虑什么情况下,两个串拼接会是一个回文串。
case 1: s=reverse(t)
$$\begin{cases}
s = a_0\dots a_{n-1} \\
t = a_{n-1}\dots a_0
\end{cases}$$
case 2
$$\begin{cases}
s = a_0\dots a_i\,\overbrace{a_{i+1}\dots a_{n-1}}^{\mathrm{palindrome}} \\
t = a_i\dots a_0
\end{cases}$$
即 $s$ 比 $t$ 长:$s$ 的前 $|t|$ 个字符恰好是 $t$ 的反串,剩下的部分是回文。
以及对称的情况
$$\begin{cases}
s = a_{n-1}\dots a_i \\
t = \overbrace{a_0\dots a_{i-1}}^{\mathrm{palindrome}}\,a_i\dots a_{n-1}
\end{cases}$$
即 $t$ 比 $s$ 长:$t$ 的后 $|s|$ 个字符恰好是 $s$ 的反串,前面剩下的部分是回文。
因此可以对每个串用 Manacher 算法求出它的每个前缀、每个后缀是否是回文,再把所有的串插入 trie 树,沿着 $t$ 的反串在树上行走来统计答案。
使用 STL 可能会超时(TLE)。
#include <cstdio>
#include <stack>
#include <set>
#include <cmath>
#include <map>
#include <time.h>
#include <vector>
#include <iostream>
#include <string>
#include <cstring>
#include <algorithm>
//#include <memory.h>
#include <cstdlib>
#include <queue>
#include <iomanip>
#include <cassert>
// #include <unordered_map>
// Short aliases / helpers used throughout this solution.
#define P pair<int, int>
#define LL long long
#define LD long double
#define PLL pair<LL, LL>
#define mset(a, b) memset(a, b, sizeof(a))
#define rep(i, a, b) for (int i = a; i < b; i++)
#define PI acos(-1.0)
// Parenthesized so that e.g. random(a + b) means rand() % (a + b).
#define random(x) (rand() % (x))
// Prints "<expression> <value>" followed by a newline.
#define debug(x) cout << #x << " " << x << "\n"
using namespace std;
const int inf = 0x3f3f3f3f;
const LL __64inf = 0x3f3f3f3f3f3f3f3f;
// Capacity bound: total input length is < 2e6, plus slack.
#ifdef DEBUG
const int MAX = 2e6 + 50;
#else
const int MAX = 2e6 + 50;
#endif
const int mod = 998244353;
// Length of each input string.
int str_len[MAX];
// All input strings stored back-to-back; str_p[i] points at string i.
char str[MAX<<1];
char* str_p[MAX];
// Flag storage, packed with the same offsets as `str`:
//   sufix_palin_p[i][k]   == 1 iff s_i[k..len-1] is a palindrome,
//   prefix_parlin_p[i][k] == 1 iff s_i[0..k] is a palindrome.
int sufix_palin[MAX];
int prefix_parlin[MAX];
int *sufix_palin_p[MAX];
int *prefix_parlin_p[MAX];
// Manacher radii scratch buffer. solve() fills 2*len+1 entries (the
// '#'-interleaved string), and a single string may be almost as long as the
// whole input, so this needs about 2*MAX slots; MAX alone overflows for any
// string longer than ~1e6.
int radius[MAX<<1];
int n;
// Number of empty input strings.
int zero;
// Number of non-empty input strings that are palindromes.
int palin;
// Trie over all input strings (the type keeps its historical name "Tire").
// Node 0 is the root; a child index of 0 means "no such child".
struct Tire
{
    int tot;               // number of nodes allocated so far
    int nex[MAX][26];      // nex[v][c]: child of node v on letter 'a' + c
    int siz[MAX];          // number of inserted strings ending exactly at v
    int _siz[MAX];         // bumped at the node reached after a string's
                           // second-to-last character; not read in this file
    int rest_palin[MAX];   // number of inserted strings passing through v whose
                           // remaining (non-empty) part after v is a palindrome

    // Zeroes the child links and all counters of node v.
    void clear_node(int v){
        memset(nex[v], 0, sizeof(nex[v]));
        siz[v] = 0;
        _siz[v] = 0;
        rest_palin[v] = 0;
    }

    // Resets the trie to a lone, empty root.
    void init(){
        clear_node(0);
        tot = 1;
    }

    // Inserts string idx, i.e. s[0..str_len[idx]). Characters index the
    // 26-wide child table via s[pos] - 'a', so they must be 'a'..'z'.
    // sufix_palin_p[idx] must already be filled in before calling this.
    void add(const int idx, char *s){
        const int len = str_len[idx];
        int node = 0;
        for(int pos = 0; pos < len; ++pos){
            const int parent = node;
            int &child = nex[parent][s[pos] - 'a'];
            if(!child){
                clear_node(tot);   // tot != parent, so `child` is untouched
                child = tot++;
            }
            node = child;
            if(pos + 2 == len) _siz[node]++;
            // s[pos..len-1] is a palindrome: credit the node spelling s[0..pos-1].
            if(sufix_palin_p[idx][pos]) rest_palin[parent]++;
        }
        siz[node]++;
    }

    // Child-link row of node idx, enabling the tire[v][c] syntax.
    int* operator [] (int idx){
        return nex[idx];
    }
} tire;
// Builds the Manacher transform of s[0..len): a '#' before, between and after
// every character, e.g. "abc" -> "#a#b#c#". The result has length 2*len+1;
// even-length palindromes of s are centered on a '#', odd-length ones on a
// character. A non-positive len yields just "#".
std::string manacherStr(int len, const char *s){
    std::string res;
    // The final size is known up front: allocate once instead of growing
    // repeatedly (this is called once per input string).
    res.reserve(len > 0 ? 2 * static_cast<std::size_t>(len) + 1 : 1);
    res += '#';
    for(int i = 0; i < len; i++){
        res += s[i];
        res += '#';
    }
    return res;
}
void solve(const int idx, char *s){
if(str_len[idx] == 0) return ;
int R = -1;
int C = -1;
int Max = -1;
string str = manacherStr(str_len[idx], s);
for(int i = 0; i < str.size(); i++){
radius[i] = R > i ? min(radius[2*C-i], R-i+1) : 1;
while(i + radius[i] < str.size() and i - radius[i] > -1) {
if(str[i-radius[i]] == str[i+radius[i]])
radius[i]++;
else break;
}
if(i + radius[i] > R) {
R = i + radius[i]-1;
C = i;
}
Max = max(Max, radius[i]);
}
if(Max - 1 == str_len[idx]) palin++;
for(int i = 1; i+1 < str.size(); i++){
if(str[i] == '#') {
int len = radius[i]-1;
int r = len / 2;
int start_idx = (i-2) / 2 - r + 1;
if(start_idx + len - 1 == str_len[idx]-1)
sufix_palin_p[idx][start_idx] = 1;
if(start_idx == 0)
prefix_parlin_p[idx][start_idx+len-1] = 1;
}
else {
int len = radius[i]-1;
int r = len / 2;
int start_idx = (i-1)/2 - r;
if(start_idx + len - 1 == str_len[idx]-1)
sufix_palin_p[idx][start_idx] = 1;
if(start_idx == 0)
prefix_parlin_p[idx][start_idx+len-1] = 1;
}
}
}
LL calcu(const int idx, char* s){
int rt = 0;
LL res = 0;
for(int i = str_len[idx]-1; i >= 0; i--){
int c = s[i] - 'a';
rt = tire[rt][c] ;
if(!rt) break;
if(i == 0){
res += (LL)tire.siz[rt];
res += (LL)tire.rest_palin[rt];
}
else if(i >= 1 and prefix_parlin_p[idx][i-1]){
res += (LL)tire.siz[rt];
}
}
return res;
}
// Reads n strings, each given as "<length> <string>". It precomputes their
// palindromic prefixes/suffixes, inserts them into the trie, and prints the
// number of ordered pairs (i, j), i == j included, with s_i + s_j a palindrome.
int main(){
#ifdef DEBUG
    freopen("in", "r", stdin);
#endif
    tire.init();
    scanf("%d", &n);
    int cur = 0;
    for(int i = 0; i < n; i++) {
        int x;
        // Strings are packed back-to-back in `str`; the '\0' written by %s is
        // overwritten by the next string.
        // NOTE(review): %s skips whitespace and needs a non-empty token, so a
        // length-0 line with no string would consume the next line's length;
        // confirm how empty strings appear in the input.
        scanf("%d%s", &x, &str[cur]);
        zero += x == 0;
        str_p[i] = str+cur;
        str_len[i] = x;
        prefix_parlin_p[i] = prefix_parlin+cur;
        sufix_palin_p[i] = sufix_palin +cur;
        cur += x;
    }
    for(int i = 0; i < n; i++){
        solve(i, str_p[i]);
    }
    for(int i = 0; i < n; i++)
        tire.add(i, str_p[i]);
    LL ans = 0;
    // Root counters would count empty strings (siz[0]) and whole-string
    // palindromes (rest_palin[0]); those pairs are added via `zero` below.
    tire.rest_palin[0] = tire.siz[0] = 0;
    for(int i = 0; i < n; i++){
#ifdef DEBUG
        printf("debug i:%d calcu:%lld\n", i, (long long)calcu(i, str_p[i]));
#endif
        ans += 1LL * calcu(i, str_p[i]);
    }
    // empty + empty.
    ans += 1LL * zero * zero;
    // empty + non-empty palindrome, in both orders.
    ans += 1LL * zero * palin * 2;
#ifdef DEBUG
    for(int i = 0; i < n; i++){
        for(int j = 0; j < str_len[i]; j++)
            cout << prefix_parlin_p[i][j] << " ";
        puts("");
        for(int j = 0; j < str_len[i]; j++)
            cout << sufix_palin_p[i][j] << " " ;
        puts("\n");
    }
#endif
    printf("%lld\n", ans);
    return 0;
}